entrenar/train/transformer_trainer/
gpu_grad_accumulator.rs1#[cfg(feature = "cuda")]
13use std::sync::Arc;
14
15#[cfg(feature = "cuda")]
16use trueno_gpu::driver::{CudaContext, CudaStream, GpuBuffer};
17
18#[cfg(feature = "cuda")]
19use crate::autograd::cuda_forward::inplace_add_gpu;
20#[cfg(feature = "cuda")]
21use crate::transformer::cuda_block::CudaGradWorkspace;
22#[cfg(feature = "cuda")]
23use crate::transformer::TransformerConfig;
24
25#[cfg(feature = "cuda")]
26use super::grad_accumulator::BLOCK_GRAD_COMPONENTS;
27
#[cfg(feature = "cuda")]
/// Converts any `Debug`-printable GPU failure into the crate error type.
///
/// The value is rendered with `{:?}` and wrapped in `Error::ConfigError`
/// behind a `"GPU error: "` prefix, so it can be passed directly to `map_err`.
fn gpu_err(e: impl std::fmt::Debug) -> crate::error::Error {
    let message = format!("GPU error: {e:?}");
    crate::error::Error::ConfigError(message)
}
32
#[cfg(feature = "cuda")]
/// GPU-side gradient accumulators for one transformer block.
///
/// Holds one `f32` GPU buffer per gradient component, in the order given by the
/// `sizes` array passed to `new`.
pub struct GpuBlockGradAccum {
    /// One accumulation buffer per gradient component.
    pub components: Vec<GpuBuffer<f32>>,
    /// Host-side zeros, as long as the largest component; sliced to clear each buffer.
    // NOTE(review): every block owns its own copy of this vector; sharing a single
    // zero buffer across blocks would save host memory but needs a signature change.
    zero_host: Vec<f32>,
}

#[cfg(feature = "cuda")]
impl GpuBlockGradAccum {
    /// Allocates one GPU buffer per entry of `sizes`, plus the host zero buffer.
    ///
    /// # Errors
    /// Returns the first allocation failure, converted through `gpu_err`.
    fn new(ctx: &Arc<CudaContext>, sizes: &[usize; BLOCK_GRAD_COMPONENTS]) -> crate::Result<Self> {
        // Collecting into `Result` stops at the first failed allocation,
        // just like an explicit loop with `?` would.
        let components = sizes
            .iter()
            .map(|&len| GpuBuffer::new(ctx, len).map_err(gpu_err))
            .collect::<Result<Vec<GpuBuffer<f32>>, _>>()?;

        let longest = sizes.iter().copied().max().unwrap_or(0);
        let zero_host = vec![0.0f32; longest];

        Ok(Self { components, zero_host })
    }

    /// Overwrites every component buffer with zeros copied from the host.
    ///
    /// # Errors
    /// Returns the first failed `copy_from_host`, converted through `gpu_err`.
    fn zero_all(&mut self) -> crate::Result<()> {
        // Split the borrow so the zero slice and the buffers can be used together.
        let Self { components, zero_host } = self;
        for buf in components.iter_mut() {
            // `zero_host` is as long as the largest component, so this slice is in range.
            let zeros = &zero_host[..buf.len()];
            buf.copy_from_host(zeros).map_err(gpu_err)?;
        }
        Ok(())
    }
}
62
#[cfg(feature = "cuda")]
/// Accumulates gradients across micro-batches, keeping block, LM-head and
/// final-norm gradients in GPU buffers and the embedding gradient on the host.
pub struct GpuGradientAccumulator {
    /// One accumulator per transformer block (`config.num_hidden_layers` entries).
    pub block_accums: Vec<GpuBlockGradAccum>,
    /// LM-head gradient accumulator (`vocab_size * hidden_size` elements).
    pub lm_head_accum: GpuBuffer<f32>,
    /// Final-norm gradient accumulator (`hidden_size` elements).
    pub final_norm_accum: GpuBuffer<f32>,
    /// Host-side embedding gradient accumulator (`vocab_size * hidden_size` elements).
    /// Nothing in this type adds to it; it is only allocated and zeroed here.
    pub embedding_accum: Vec<f32>,
    /// Number of accumulated micro-batches. Reset by `zero_all`; this type never
    /// increments it, so the caller is presumably responsible for that — TODO confirm.
    pub accumulated_count: usize,
    /// Element count of each per-block gradient component.
    pub block_component_sizes: [usize; BLOCK_GRAD_COMPONENTS],
    /// Host zeros used to clear `lm_head_accum`.
    lm_head_zero: Vec<f32>,
    /// Host zeros used to clear `final_norm_accum`.
    final_norm_zero: Vec<f32>,
}

#[cfg(feature = "cuda")]
impl GpuGradientAccumulator {
    /// Allocates all accumulation buffers for `config` and zeroes them.
    ///
    /// Prints a one-line VRAM estimate for the GPU-resident buffers to stderr.
    ///
    /// # Errors
    /// Returns an error if a GPU allocation or the initial zeroing fails.
    pub fn new(ctx: &Arc<CudaContext>, config: &TransformerConfig) -> crate::Result<Self> {
        let h = config.hidden_size;
        let kv = config.num_kv_heads * config.head_dim();
        let i = config.intermediate_size;
        let v = config.vocab_size;

        let sizes =
            super::grad_accumulator::PerBlockGradientAccumulator::compute_block_sizes(h, kv, i);

        let mut block_accums = Vec::with_capacity(config.num_hidden_layers);
        for _ in 0..config.num_hidden_layers {
            block_accums.push(GpuBlockGradAccum::new(ctx, &sizes)?);
        }

        let lm_head_accum = GpuBuffer::new(ctx, v * h).map_err(gpu_err)?;
        let final_norm_accum = GpuBuffer::new(ctx, h).map_err(gpu_err)?;
        let embedding_accum = vec![0.0f32; v * h];

        // Element counts of the GPU buffers only; 4.0 is the size of an `f32` in bytes.
        let per_block_elems = sizes.iter().sum::<usize>() as f64;
        let total_vram_mb =
            (block_accums.len() as f64 * per_block_elems + (v * h) as f64 + h as f64) * 4.0
                / (1024.0 * 1024.0);

        eprintln!(
            " GPU gradient accumulation: {} blocks, {:.1} MB VRAM",
            config.num_hidden_layers, total_vram_mb,
        );

        let mut accum = Self {
            block_accums,
            lm_head_accum,
            final_norm_accum,
            embedding_accum,
            accumulated_count: 0,
            block_component_sizes: sizes,
            lm_head_zero: vec![0.0f32; v * h],
            final_norm_zero: vec![0.0f32; h],
        };

        // Freshly allocated buffers are explicitly cleared before first use.
        accum.zero_all()?;

        Ok(accum)
    }

    /// Converts a host-side element count into the `u32` length passed to
    /// `inplace_add_gpu`.
    ///
    /// A plain `as u32` cast would wrap for counts above `u32::MAX`, so only part
    /// of the gradient would be accumulated; this fails loudly instead.
    fn kernel_len(len: usize) -> Result<u32, crate::error::Error> {
        u32::try_from(len).map_err(|_| {
            crate::error::Error::ConfigError(format!(
                "GPU error: buffer length {len} does not fit in the u32 element count \
                 used by the accumulation kernel"
            ))
        })
    }

    /// Builds the error returned when `block_idx` does not name an existing block.
    fn block_out_of_range(block_idx: usize, num_blocks: usize) -> crate::error::Error {
        crate::error::Error::ConfigError(format!(
            "GPU error: block index {block_idx} out of range ({num_blocks} blocks)"
        ))
    }

    /// Adds the gradients in `workspace` to the accumulator of block `block_idx`.
    ///
    /// Each component is added with `inplace_add_gpu` on `stream`, using the
    /// element count recorded in `block_component_sizes`.
    ///
    /// # Errors
    /// Returns an error if `block_idx` is out of range, if a component size does
    /// not fit in `u32`, or if `inplace_add_gpu` fails.
    pub fn accumulate_block(
        &mut self,
        workspace: &CudaGradWorkspace,
        block_idx: usize,
        stream: &CudaStream,
    ) -> crate::Result<()> {
        let num_blocks = self.block_accums.len();
        let accum = self
            .block_accums
            .get_mut(block_idx)
            .ok_or_else(|| Self::block_out_of_range(block_idx, num_blocks))?;
        let ws_bufs = workspace_buffers(workspace);

        let buffer_pairs = accum.components.iter_mut().zip(ws_bufs.iter());
        for ((accum_buf, ws_buf), &len) in buffer_pairs.zip(self.block_component_sizes.iter()) {
            let n = Self::kernel_len(len)?;
            inplace_add_gpu(accum_buf, ws_buf, n, stream).map_err(gpu_err)?;
        }
        Ok(())
    }

    /// Adds the LM-head and final-norm gradients to their accumulators.
    ///
    /// The element count of each add is the length of the incoming gradient buffer.
    // NOTE(review): assumes each gradient buffer is no longer than its accumulator;
    // confirm whether `inplace_add_gpu` validates this itself.
    ///
    /// # Errors
    /// Returns an error if a buffer length does not fit in `u32` or if
    /// `inplace_add_gpu` fails. Both lengths are checked before any add runs.
    pub fn accumulate_nonblock(
        &mut self,
        lm_head_grad: &GpuBuffer<f32>,
        final_norm_grad: &GpuBuffer<f32>,
        stream: &CudaStream,
    ) -> crate::Result<()> {
        let lm_head_len = Self::kernel_len(lm_head_grad.len())?;
        let final_norm_len = Self::kernel_len(final_norm_grad.len())?;

        inplace_add_gpu(&mut self.lm_head_accum, lm_head_grad, lm_head_len, stream)
            .map_err(gpu_err)?;
        inplace_add_gpu(&mut self.final_norm_accum, final_norm_grad, final_norm_len, stream)
            .map_err(gpu_err)?;
        Ok(())
    }

    /// Copies the accumulated gradients of block `block_idx` into `workspace`.
    ///
    /// # Errors
    /// Returns an error if `block_idx` is out of range or if a buffer copy fails.
    pub fn upload_to_workspace(
        &self,
        workspace: &mut CudaGradWorkspace,
        block_idx: usize,
    ) -> crate::Result<()> {
        let accum = self
            .block_accums
            .get(block_idx)
            .ok_or_else(|| Self::block_out_of_range(block_idx, self.block_accums.len()))?;
        let ws_bufs = workspace_buffers_mut(workspace);

        for (ws_buf, accum_buf) in ws_bufs.into_iter().zip(accum.components.iter()) {
            ws_buf.copy_from_buffer(accum_buf).map_err(gpu_err)?;
        }
        Ok(())
    }

    /// Copies the accumulated LM-head and final-norm gradients into the given buffers.
    ///
    /// # Errors
    /// Returns an error if either buffer copy fails.
    pub fn upload_nonblock(
        &self,
        lm_head_grad: &mut GpuBuffer<f32>,
        final_norm_grad: &mut GpuBuffer<f32>,
    ) -> crate::Result<()> {
        lm_head_grad.copy_from_buffer(&self.lm_head_accum).map_err(gpu_err)?;
        final_norm_grad.copy_from_buffer(&self.final_norm_accum).map_err(gpu_err)?;
        Ok(())
    }

    /// Clears every accumulator (GPU and host) and resets `accumulated_count` to 0.
    ///
    /// # Errors
    /// Returns an error if any host-to-GPU zero copy fails. In that case the
    /// buffers after the failing one, and the counter, are left untouched.
    pub fn zero_all(&mut self) -> crate::Result<()> {
        for block in &mut self.block_accums {
            block.zero_all()?;
        }
        self.lm_head_accum.copy_from_host(&self.lm_head_zero).map_err(gpu_err)?;
        self.final_norm_accum.copy_from_host(&self.final_norm_zero).map_err(gpu_err)?;
        self.embedding_accum.fill(0.0);
        self.accumulated_count = 0;
        Ok(())
    }
}
224
#[cfg(feature = "cuda")]
/// Returns shared references to the nine gradient buffers of `ws` in a fixed order:
/// Q, K, V, O projections, then gate, up, down, then input norm and post-attention norm.
///
/// `workspace_buffers_mut` uses the same order.
fn workspace_buffers(ws: &CudaGradWorkspace) -> [&GpuBuffer<f32>; BLOCK_GRAD_COMPONENTS] {
    // Destructuring through the reference binds each field by reference.
    let CudaGradWorkspace {
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
        ..
    } = ws;

    [
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
    ]
}
240
#[cfg(feature = "cuda")]
/// Returns mutable references to the nine gradient buffers of `ws` in a fixed order:
/// Q, K, V, O projections, then gate, up, down, then input norm and post-attention norm.
///
/// `workspace_buffers` uses the same order.
fn workspace_buffers_mut(
    ws: &mut CudaGradWorkspace,
) -> [&mut GpuBuffer<f32>; BLOCK_GRAD_COMPONENTS] {
    // Destructuring through `&mut` yields disjoint mutable borrows of each field.
    let CudaGradWorkspace {
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
        ..
    } = ws;

    [
        grad_w_q,
        grad_w_k,
        grad_w_v,
        grad_w_o,
        grad_gate,
        grad_up,
        grad_down,
        grad_input_norm,
        grad_post_attn_norm,
    ]
}