Skip to main content

entrenar/train/transformer_trainer/
gpu_grad_accumulator.rs

1// ALB-091: GPU-resident gradient accumulation.
2//
3// Replaces CPU-side PerBlockGradientAccumulator when running with ga > 1.
4// Eliminates the 24 × ga stream.synchronize() + D2H transfers per optimizer step
5// that caused a 2.6x throughput regression (7.6K → 2.9K tok/s for 350M model).
6//
7// All accumulation happens on GPU via in-place element-wise add (ResidualAddKernel).
8// Only ONE stream sync per optimizer step (not per micro-batch per block).
9//
10// VRAM cost: ~1.55 GB for 350M model (24 blocks × ~60 MB/block + 128 MB LM head).
11
12#[cfg(feature = "cuda")]
13use std::sync::Arc;
14
15#[cfg(feature = "cuda")]
16use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
17
18#[cfg(feature = "cuda")]
19use crate::autograd::cuda_forward::inplace_add_gpu;
20#[cfg(feature = "cuda")]
21use crate::transformer::cuda_block::CudaGradWorkspace;
22#[cfg(feature = "cuda")]
23use crate::transformer::TransformerConfig;
24
25#[cfg(feature = "cuda")]
26use super::grad_accumulator::BLOCK_GRAD_COMPONENTS;
27
/// Wraps any `Debug`-printable GPU driver error in the crate's `ConfigError` variant.
///
/// The resulting message is the error's `Debug` rendering prefixed with `GPU error: `.
#[cfg(feature = "cuda")]
fn gpu_err(e: impl std::fmt::Debug) -> crate::error::Error {
    let message = format!("GPU error: {e:?}");
    crate::error::Error::ConfigError(message)
}
32
/// Gradient accumulation buffers on the GPU for a single transformer block.
///
/// Each block owns its own host-side zero vector, which is used as the source
/// of the host-to-device copy that clears the device buffers.
#[cfg(feature = "cuda")]
pub struct GpuBlockGradAccum {
    /// One device buffer per `CudaGradWorkspace` gradient component, in the order
    /// `[w_q, w_k, w_v, w_o, gate, up, down, input_norm, post_attn_norm]`
    /// (9 entries).
    pub components: Vec<GpuBuffer<f32>>,
    /// Host vector of zeros, allocated once and as long as the largest component,
    /// so a prefix of it can clear any of the buffers above.
    zero_host: Vec<f32>,
}
42
#[cfg(feature = "cuda")]
impl GpuBlockGradAccum {
    /// Allocates one device buffer per entry of `sizes` (element counts, in
    /// component order) and a host zero vector as long as the largest entry.
    ///
    /// This constructor does not clear the device buffers; `zero_all` does that.
    ///
    /// # Errors
    /// Returns the first allocation failure, wrapped by `gpu_err`.
    fn new(ctx: &Arc<CudaContext>, sizes: &[usize; BLOCK_GRAD_COMPONENTS]) -> crate::Result<Self> {
        // Stops at the first failed allocation, exactly like a push loop with `?`.
        let components: Vec<GpuBuffer<f32>> = sizes
            .iter()
            .map(|&len| GpuBuffer::new(ctx, len).map_err(gpu_err))
            .collect::<crate::Result<_>>()?;

        let largest = sizes.iter().fold(0usize, |acc, &len| acc.max(len));
        let zero_host = vec![0.0f32; largest];

        Ok(Self { components, zero_host })
    }

    /// Overwrites every component buffer with zeros via a host-to-device copy.
    ///
    /// # Errors
    /// Returns the first copy failure, wrapped by `gpu_err`; buffers after the
    /// failing one are left untouched.
    fn zero_all(&mut self) -> crate::Result<()> {
        let Self { components, zero_host } = self;
        components.iter_mut().try_for_each(|device_buf| -> crate::Result<()> {
            // Only the first `len` zeros are needed for this particular buffer.
            let len = device_buf.len();
            device_buf.copy_from_host(&zero_host[..len]).map_err(gpu_err)?;
            Ok(())
        })
    }
}
62
/// Whole-model gradient accumulator whose block, LM-head and final-norm
/// buffers live on the GPU.
///
/// Stands in for the CPU-side `PerBlockGradientAccumulator` when CUDA is
/// available, so that no device-to-host transfer is needed while micro-batches
/// are being accumulated.
#[cfg(feature = "cuda")]
pub struct GpuGradientAccumulator {
    /// Accumulation buffers, one entry per transformer block.
    pub block_accums: Vec<GpuBlockGradAccum>,
    /// Accumulated LM head gradient, `vocab_size × hidden_size` elements.
    pub lm_head_accum: GpuBuffer<f32>,
    /// Accumulated final-norm gradient, `hidden_size` elements.
    pub final_norm_accum: GpuBuffer<f32>,
    /// Accumulated embedding gradient; kept on the host because the embedding
    /// itself is CPU-side.
    pub embedding_accum: Vec<f32>,
    /// How many micro-batches have been accumulated since the last reset.
    pub accumulated_count: usize,
    /// Element count of each per-block component, in component order.
    pub block_component_sizes: [usize; BLOCK_GRAD_COMPONENTS],
    /// Host zeros used to clear `lm_head_accum` (allocated once).
    lm_head_zero: Vec<f32>,
    /// Host zeros used to clear `final_norm_accum` (allocated once).
    final_norm_zero: Vec<f32>,
}
86
#[cfg(feature = "cuda")]
impl GpuGradientAccumulator {
    /// Allocate GPU accumulation buffers matching the model architecture.
    ///
    /// Per-block component sizes come from
    /// `PerBlockGradientAccumulator::compute_block_sizes(hidden, kv, intermediate)`.
    /// The LM head and embedding accumulators hold `vocab_size × hidden_size`
    /// elements, the final norm accumulator `hidden_size`. A one-line summary
    /// of the VRAM footprint is written to stderr.
    ///
    /// VRAM cost: ~1.55 GB for 350M model (H=1024, I=4096, L=24, V=32768).
    ///
    /// # Errors
    /// Returns an error if any device allocation or the initial zeroing fails.
    pub fn new(ctx: &Arc<CudaContext>, config: &TransformerConfig) -> crate::Result<Self> {
        let h = config.hidden_size;
        let kv = config.num_kv_heads * config.head_dim();
        let i = config.intermediate_size;
        let v = config.vocab_size;
        // LM head and embedding accumulators share this element count.
        let lm_head_len = v * h;

        let sizes =
            super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(h, kv, i);

        let block_accums = (0..config.num_hidden_layers)
            .map(|_| GpuBlockGradAccum::new(ctx, &sizes))
            .collect::<crate::Result<Vec<_>>>()?;

        let lm_head_accum = GpuBuffer::new(ctx, lm_head_len).map_err(gpu_err)?;
        let final_norm_accum = GpuBuffer::new(ctx, h).map_err(gpu_err)?;
        let embedding_accum = vec![0.0f32; lm_head_len];

        // f32 elements → bytes (× 4) → MiB.
        let total_vram_mb = (block_accums.len() as f64 * sizes.iter().sum::<usize>() as f64
            + lm_head_len as f64
            + h as f64)
            * 4.0
            / (1024.0 * 1024.0);

        eprintln!(
            "  GPU gradient accumulation: {} blocks, {:.1} MB VRAM",
            config.num_hidden_layers, total_vram_mb,
        );

        let mut accum = Self {
            block_accums,
            lm_head_accum,
            final_norm_accum,
            embedding_accum,
            accumulated_count: 0,
            block_component_sizes: sizes,
            lm_head_zero: vec![0.0f32; lm_head_len],
            final_norm_zero: vec![0.0f32; h],
        };

        // CRITICAL: GpuBuffer::new() does NOT zero VRAM (cuMemAlloc returns
        // uninitialized memory). Without this, the first accumulation window
        // adds real gradients to garbage → corrupted first optimizer step.
        accum.zero_all()?;

        Ok(accum)
    }

    /// Converts an element count into the `u32` taken by `inplace_add_gpu`.
    ///
    /// A plain `as u32` cast would silently wrap for counts above `u32::MAX`
    /// and make the kernel process only part of the buffer, so this returns
    /// an error instead.
    fn kernel_len(what: &str, n: usize) -> Result<u32, crate::error::Error> {
        u32::try_from(n).map_err(|_| {
            crate::error::Error::ConfigError(format!(
                "GPU error: {what} has {n} elements, which does not fit the u32 kernel length"
            ))
        })
    }

    /// Builds the error returned when `block_idx` does not name an existing block.
    fn block_index_error(block_idx: usize, num_blocks: usize) -> crate::error::Error {
        crate::error::Error::ConfigError(format!(
            "GPU error: block index {block_idx} out of range ({num_blocks} blocks)"
        ))
    }

    /// Rejects a gradient that has more elements than the accumulator it is added to.
    fn ensure_fits(
        what: &str,
        grad_len: usize,
        accum_len: usize,
    ) -> Result<(), crate::error::Error> {
        if grad_len > accum_len {
            return Err(crate::error::Error::ConfigError(format!(
                "GPU error: {what} gradient has {grad_len} elements but its accumulator holds only {accum_len}"
            )));
        }
        Ok(())
    }

    /// Accumulate workspace gradients for a single block into GPU accum buffers.
    ///
    /// Uses in-place GPU add (ResidualAddKernel): accum[i] += workspace[i],
    /// once per component, over `block_component_sizes[i]` elements.
    /// No stream synchronization is issued here, and `accumulated_count` is
    /// not modified.
    ///
    /// # Errors
    /// Returns an error if `block_idx` is out of range, if a component size does
    /// not fit in `u32`, or if queuing an add fails.
    pub fn accumulate_block(
        &mut self,
        workspace: &CudaGradWorkspace,
        block_idx: usize,
        stream: &CudaStream,
    ) -> crate::Result<()> {
        let num_blocks = self.block_accums.len();
        let sizes = self.block_component_sizes;
        let accum = self
            .block_accums
            .get_mut(block_idx)
            .ok_or_else(|| Self::block_index_error(block_idx, num_blocks))?;
        let ws_bufs = workspace_buffers(workspace);

        for ((accum_buf, ws_buf), &len) in
            accum.components.iter_mut().zip(ws_bufs.iter()).zip(sizes.iter())
        {
            let n = Self::kernel_len("block gradient component", len)?;
            inplace_add_gpu(accum_buf, ws_buf, n, stream).map_err(gpu_err)?;
        }
        Ok(())
    }

    /// Accumulate non-block gradients (LM head + final norm) into GPU accum buffers.
    ///
    /// The element count of each add is the length of the incoming gradient
    /// buffer. Both gradients are validated before anything is queued, so a
    /// rejected call leaves the accumulators untouched.
    ///
    /// # Errors
    /// Returns an error if either gradient is longer than its accumulator, if a
    /// length does not fit in `u32`, or if queuing an add fails.
    pub fn accumulate_nonblock(
        &mut self,
        lm_head_grad: &GpuBuffer<f32>,
        final_norm_grad: &GpuBuffer<f32>,
        stream: &CudaStream,
    ) -> crate::Result<()> {
        let lm_head_len = lm_head_grad.len();
        let final_norm_len = final_norm_grad.len();

        // Adding more elements than the accumulator holds would run past its end.
        Self::ensure_fits("LM head", lm_head_len, self.lm_head_accum.len())?;
        Self::ensure_fits("final norm", final_norm_len, self.final_norm_accum.len())?;
        let lm_head_n = Self::kernel_len("LM head gradient", lm_head_len)?;
        let final_norm_n = Self::kernel_len("final norm gradient", final_norm_len)?;

        inplace_add_gpu(&mut self.lm_head_accum, lm_head_grad, lm_head_n, stream)
            .map_err(gpu_err)?;
        inplace_add_gpu(&mut self.final_norm_accum, final_norm_grad, final_norm_n, stream)
            .map_err(gpu_err)?;
        Ok(())
    }

    /// Copy accumulated gradients back to workspace for optimizer step.
    ///
    /// Uses synchronous D2D copy. Must be called AFTER stream.synchronize().
    ///
    /// # Errors
    /// Returns an error if `block_idx` is out of range or a copy fails.
    pub fn upload_to_workspace(
        &self,
        workspace: &mut CudaGradWorkspace,
        block_idx: usize,
    ) -> crate::Result<()> {
        let accum = self
            .block_accums
            .get(block_idx)
            .ok_or_else(|| Self::block_index_error(block_idx, self.block_accums.len()))?;
        let ws_bufs = workspace_buffers_mut(workspace);

        for (ws_buf, accum_buf) in ws_bufs.into_iter().zip(accum.components.iter()) {
            ws_buf.copy_from_buffer(accum_buf).map_err(gpu_err)?;
        }
        Ok(())
    }

    /// Copy accumulated non-block gradients to the training buffers.
    ///
    /// # Errors
    /// Returns an error if either buffer copy fails.
    pub fn upload_nonblock(
        &self,
        lm_head_grad: &mut GpuBuffer<f32>,
        final_norm_grad: &mut GpuBuffer<f32>,
    ) -> crate::Result<()> {
        lm_head_grad.copy_from_buffer(&self.lm_head_accum).map_err(gpu_err)?;
        final_norm_grad.copy_from_buffer(&self.final_norm_accum).map_err(gpu_err)?;
        Ok(())
    }

    /// Zero all accumulation buffers (call at start of each optimizer step).
    ///
    /// Uses H2D copy from pre-allocated zero buffers. Called once per optimizer step
    /// (not per micro-batch), so the H2D cost is negligible compared to savings.
    /// Also clears the host-side embedding accumulator and resets
    /// `accumulated_count` to 0.
    ///
    /// # Errors
    /// Returns the first copy failure; buffers after the failing one, the
    /// embedding accumulator and `accumulated_count` are then left unchanged.
    pub fn zero_all(&mut self) -> crate::Result<()> {
        for block in &mut self.block_accums {
            block.zero_all()?;
        }
        self.lm_head_accum.copy_from_host(&self.lm_head_zero).map_err(gpu_err)?;
        self.final_norm_accum.copy_from_host(&self.final_norm_zero).map_err(gpu_err)?;
        self.embedding_accum.fill(0.0);
        self.accumulated_count = 0;
        Ok(())
    }
}
224
/// Borrows the nine gradient buffers of `ws` as an array, in component order:
/// `w_q, w_k, w_v, w_o, gate, up, down, input_norm, post_attn_norm`.
#[cfg(feature = "cuda")]
fn workspace_buffers(ws: &CudaGradWorkspace) -> [&GpuBuffer<f32>; BLOCK_GRAD_COMPONENTS] {
    // Destructure by reference; any other workspace fields are ignored.
    let CudaGradWorkspace {
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
        ..
    } = ws;

    [
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
    ]
}
240
/// Mutably borrows the nine gradient buffers of `ws` as an array, in component
/// order: `w_q, w_k, w_v, w_o, gate, up, down, input_norm, post_attn_norm`.
#[cfg(feature = "cuda")]
fn workspace_buffers_mut(
    ws: &mut CudaGradWorkspace,
) -> [&mut GpuBuffer<f32>; BLOCK_GRAD_COMPONENTS] {
    // Destructuring splits the single `&mut` into disjoint per-field borrows.
    let CudaGradWorkspace {
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
        ..
    } = ws;

    [
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
    ]
}