// entrenar/train/transformer_trainer/grad_accumulator.rs

1#![allow(dead_code)]
// Per-block gradient accumulation for distributed data-parallel pretraining.
//
// The `CudaTransformerTrainer` uses a shared `CudaGradWorkspace` that is
// overwritten for each block during backward. For DDP, we need to:
//
// 1. After block[i]'s backward, copy workspace gradients into `block_grads[i]`
// 2. AllReduce `block_grads[i]` across workers
// 3. Run optimizer_step for block[i] with averaged gradients
//
// This module provides CPU-side accumulation buffers that hold the per-block
// gradients after they are downloaded from the GPU. These buffers are what
// gets sent over the wire for AllReduce.
//
// # Contract
//
// C-DDP-001: Per-block gradient buffers match CudaGradWorkspace component sizes.
18
/// Number of gradient components per transformer block.
///
/// Matches the `CudaGradWorkspace` component list, in `component::*` index order:
/// w_q, w_k, w_v, w_o, gate, up, down, input_norm, post_attn_norm.
pub const BLOCK_GRAD_COMPONENTS: usize = 9;
22
/// Component indices for transformer block gradients.
///
/// Each constant is the position of one gradient vector inside
/// `BlockGradientSet::components` (and inside the per-block size array).
pub mod component {
    /// Index of the `w_q` gradient.
    pub const W_Q: usize = 0;
    /// Index of the `w_k` gradient.
    pub const W_K: usize = 1;
    /// Index of the `w_v` gradient.
    pub const W_V: usize = 2;
    /// Index of the `w_o` gradient.
    pub const W_O: usize = 3;
    /// Index of the `gate` gradient.
    pub const GATE: usize = 4;
    /// Index of the `up` gradient.
    pub const UP: usize = 5;
    /// Index of the `down` gradient.
    pub const DOWN: usize = 6;
    /// Index of the `input_norm` gradient.
    pub const INPUT_NORM: usize = 7;
    /// Index of the `post_attn_norm` gradient.
    pub const POST_ATTN_NORM: usize = 8;
}
35
/// Non-block component IDs (for wire protocol).
///
/// These identify the gradient buffers that do not belong to any
/// transformer block.
pub mod non_block {
    /// ID of the LM head weight gradient.
    pub const LM_HEAD: u8 = 0;
    /// ID of the final norm weight gradient.
    pub const FINAL_NORM: u8 = 1;
    /// ID of the embedding weight gradient.
    pub const EMBEDDING: u8 = 2;
}
42
43/// Gradient set for a single transformer block (CPU-side).
44///
45/// Contains 9 flattened f32 gradient vectors, one per CudaGradWorkspace component.
46/// These are downloaded from GPU after backward and before AllReduce.
47#[derive(Debug, Clone)]
48pub struct BlockGradientSet {
49    /// Gradient components, indexed by `component::*` constants
50    pub components: Vec<Vec<f32>>,
51}
52
53impl BlockGradientSet {
54    /// Create a new zeroed gradient set with the given component sizes.
55    ///
56    /// # Arguments
57    /// * `sizes` - Element count for each of the 9 components
58    pub fn zeroed(sizes: &[usize; BLOCK_GRAD_COMPONENTS]) -> Self {
59        let components = sizes.iter().map(|&sz| vec![0.0f32; sz]).collect();
60        Self { components }
61    }
62
63    /// Total number of f32 elements across all components.
64    pub fn total_elements(&self) -> usize {
65        self.components.iter().map(Vec::len).sum()
66    }
67
68    /// Get component sizes as u32 (for wire protocol).
69    pub fn component_sizes_u32(&self) -> Vec<u32> {
70        self.components.iter().map(|c| c.len() as u32).collect()
71    }
72
73    /// Flatten all components into a single contiguous Vec<f32>.
74    pub fn flatten(&self) -> Vec<f32> {
75        let total = self.total_elements();
76        let mut flat = Vec::with_capacity(total);
77        for comp in &self.components {
78            flat.extend_from_slice(comp);
79        }
80        flat
81    }
82
83    /// Reconstruct from a flat gradient vector and component sizes.
84    ///
85    /// # Panics
86    /// Panics if `flat.len() != sum(sizes)`.
87    pub fn from_flat(flat: &[f32], sizes: &[u32]) -> Self {
88        let total: usize = sizes.iter().map(|&s| s as usize).sum();
89        assert_eq!(flat.len(), total, "flat gradient length mismatch");
90        let mut components = Vec::with_capacity(sizes.len());
91        let mut offset = 0;
92        for &sz in sizes {
93            let sz = sz as usize;
94            components.push(flat[offset..offset + sz].to_vec());
95            offset += sz;
96        }
97        Self { components }
98    }
99
100    /// Zero all gradient components (reuse buffers).
101    pub fn zero(&mut self) {
102        for comp in &mut self.components {
103            for x in comp.iter_mut() {
104                *x = 0.0;
105            }
106        }
107    }
108
109    /// Element-wise add another gradient set into this one.
110    ///
111    /// # Panics
112    /// Panics if component sizes don't match.
113    pub fn accumulate(&mut self, other: &BlockGradientSet) {
114        assert_eq!(self.components.len(), other.components.len());
115        for (dst, src) in self.components.iter_mut().zip(&other.components) {
116            assert_eq!(dst.len(), src.len(), "component size mismatch");
117            for (d, s) in dst.iter_mut().zip(src) {
118                *d += s;
119            }
120        }
121    }
122
123    /// Divide all gradient elements by a scalar (for averaging).
124    pub fn scale(&mut self, divisor: f32) {
125        let inv = 1.0 / divisor;
126        for comp in &mut self.components {
127            for x in comp.iter_mut() {
128                *x *= inv;
129            }
130        }
131    }
132
133    /// Check if any element is NaN or Inf (Jidoka safety check).
134    pub fn has_non_finite(&self) -> bool {
135        self.components.iter().any(|comp| comp.iter().any(|x| !x.is_finite()))
136    }
137}
138
139/// Per-block gradient accumulator for the full model.
140///
141/// Holds one `BlockGradientSet` per transformer block, plus separate
142/// buffers for LM head, final norm, and embedding gradients.
143///
144/// # VRAM Cost
145///
146/// For 350M model (H=1024, I=4096, D_kv=256, L=24):
147/// - Per block: ~2.8M f32 = ~11.2 MB
148/// - 24 blocks: ~268 MB total
149/// - Non-block (LM head + final norm + embedding): ~67 MB
150/// - Total: ~335 MB CPU RAM (not VRAM — these are CPU-side buffers)
151#[derive(Debug)]
152pub struct PerBlockGradientAccumulator {
153    /// Per-block gradient buffers
154    pub block_grads: Vec<BlockGradientSet>,
155    /// LM head weight gradient [vocab_size * hidden_size]
156    pub lm_head_grad: Vec<f32>,
157    /// Final norm weight gradient [hidden_size]
158    pub final_norm_grad: Vec<f32>,
159    /// Embedding weight gradient [vocab_size * hidden_size]
160    pub embedding_grad: Vec<f32>,
161    /// Number of accumulated micro-batches
162    pub accumulated_count: usize,
163    /// Component sizes per block (cached for wire protocol)
164    pub block_component_sizes: [usize; BLOCK_GRAD_COMPONENTS],
165}
166
167impl PerBlockGradientAccumulator {
168    /// Create a new accumulator with zeroed buffers.
169    ///
170    /// # Arguments
171    /// * `num_blocks` - Number of transformer layers (e.g., 24 for 350M)
172    /// * `block_sizes` - Element count for each of the 9 gradient components per block
173    /// * `vocab_size` - Vocabulary size (for LM head and embedding)
174    /// * `hidden_size` - Hidden dimension (for final norm)
175    pub fn new(
176        num_blocks: usize,
177        block_sizes: [usize; BLOCK_GRAD_COMPONENTS],
178        vocab_size: usize,
179        hidden_size: usize,
180    ) -> Self {
181        let block_grads = (0..num_blocks).map(|_| BlockGradientSet::zeroed(&block_sizes)).collect();
182
183        Self {
184            block_grads,
185            lm_head_grad: vec![0.0; vocab_size * hidden_size],
186            final_norm_grad: vec![0.0; hidden_size],
187            embedding_grad: vec![0.0; vocab_size * hidden_size],
188            accumulated_count: 0,
189            block_component_sizes: block_sizes,
190        }
191    }
192
193    /// Compute the per-block component sizes from model architecture.
194    ///
195    /// # Arguments
196    /// * `hidden_size` - H
197    /// * `kv_hidden_size` - D_kv = (H / num_heads) * num_kv_heads
198    /// * `intermediate_size` - I (FFN intermediate dimension)
199    pub fn compute_block_sizes(
200        hidden_size: usize,
201        kv_hidden_size: usize,
202        intermediate_size: usize,
203    ) -> [usize; BLOCK_GRAD_COMPONENTS] {
204        [
205            hidden_size * hidden_size,       // w_q
206            hidden_size * kv_hidden_size,    // w_k
207            hidden_size * kv_hidden_size,    // w_v
208            hidden_size * hidden_size,       // w_o
209            hidden_size * intermediate_size, // gate
210            hidden_size * intermediate_size, // up
211            intermediate_size * hidden_size, // down
212            hidden_size,                     // input_norm
213            hidden_size,                     // post_attn_norm
214        ]
215    }
216
217    /// Zero all accumulated gradients (call at the start of each step).
218    pub fn zero_all(&mut self) {
219        for block_grad in &mut self.block_grads {
220            block_grad.zero();
221        }
222        self.lm_head_grad.iter_mut().for_each(|x| *x = 0.0);
223        self.final_norm_grad.iter_mut().for_each(|x| *x = 0.0);
224        self.embedding_grad.iter_mut().for_each(|x| *x = 0.0);
225        self.accumulated_count = 0;
226    }
227
228    /// Average accumulated gradients by dividing by the accumulated count.
229    pub fn average(&mut self) {
230        if self.accumulated_count <= 1 {
231            return;
232        }
233        let n = self.accumulated_count as f32;
234        for block_grad in &mut self.block_grads {
235            block_grad.scale(n);
236        }
237        let inv = 1.0 / n;
238        for x in &mut self.lm_head_grad {
239            *x *= inv;
240        }
241        for x in &mut self.final_norm_grad {
242            *x *= inv;
243        }
244        for x in &mut self.embedding_grad {
245            *x *= inv;
246        }
247    }
248
249    /// Check if any block has NaN or Inf gradients (Jidoka).
250    pub fn has_non_finite(&self) -> bool {
251        self.block_grads.iter().any(BlockGradientSet::has_non_finite)
252            || self.lm_head_grad.iter().any(|x| !x.is_finite())
253            || self.final_norm_grad.iter().any(|x| !x.is_finite())
254            || self.embedding_grad.iter().any(|x| !x.is_finite())
255    }
256
257    /// Number of transformer blocks.
258    pub fn num_blocks(&self) -> usize {
259        self.block_grads.len()
260    }
261}
262
#[cfg(test)]
mod tests {
    //! Unit tests for `BlockGradientSet` and `PerBlockGradientAccumulator`.

    use super::*;

    #[test]
    fn test_block_gradient_set_zeroed() {
        let sizes = [100, 50, 50, 100, 200, 200, 200, 10, 10];
        let bg = BlockGradientSet::zeroed(&sizes);
        assert_eq!(bg.components.len(), 9);
        assert_eq!(bg.total_elements(), 920);
        assert!(bg.components[0].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_block_gradient_set_flatten_roundtrip() {
        let sizes = [4, 2, 2, 4, 8, 8, 8, 1, 1];
        let mut bg = BlockGradientSet::zeroed(&sizes);
        // Fill with test data that is unique per (component, element).
        for (i, comp) in bg.components.iter_mut().enumerate() {
            for (j, val) in comp.iter_mut().enumerate() {
                *val = (i * 100 + j) as f32;
            }
        }
        let flat = bg.flatten();
        assert_eq!(flat.len(), 38);

        let sizes_u32 = bg.component_sizes_u32();
        let reconstructed = BlockGradientSet::from_flat(&flat, &sizes_u32);
        for (orig, recon) in bg.components.iter().zip(&reconstructed.components) {
            assert_eq!(orig, recon);
        }
    }

    #[test]
    #[should_panic(expected = "flat gradient length mismatch")]
    fn test_block_gradient_set_from_flat_length_mismatch() {
        // Two elements supplied, three declared.
        let _ = BlockGradientSet::from_flat(&[1.0, 2.0], &[3]);
    }

    #[test]
    fn test_block_gradient_set_accumulate() {
        let sizes = [2, 2, 2, 2, 2, 2, 2, 1, 1];
        let mut a = BlockGradientSet::zeroed(&sizes);
        let mut b = BlockGradientSet::zeroed(&sizes);
        a.components[0] = vec![1.0, 2.0];
        b.components[0] = vec![3.0, 4.0];
        a.accumulate(&b);
        assert_eq!(a.components[0], vec![4.0, 6.0]);
    }

    #[test]
    fn test_block_gradient_set_scale() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut bg = BlockGradientSet::zeroed(&sizes);
        bg.components[0] = vec![6.0, 9.0];
        bg.scale(3.0);
        assert!((bg.components[0][0] - 2.0).abs() < 1e-6);
        assert!((bg.components[0][1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_block_gradient_set_has_non_finite() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut bg = BlockGradientSet::zeroed(&sizes);
        assert!(!bg.has_non_finite());
        bg.components[0][0] = f32::NAN;
        assert!(bg.has_non_finite());
    }

    #[test]
    fn test_accumulator_new() {
        let sizes = PerBlockGradientAccumulator::compute_block_sizes(1024, 256, 4096);
        let acc = PerBlockGradientAccumulator::new(24, sizes, 32768, 1024);
        assert_eq!(acc.num_blocks(), 24);
        assert_eq!(acc.lm_head_grad.len(), 32768 * 1024);
        assert_eq!(acc.final_norm_grad.len(), 1024);
        assert_eq!(acc.embedding_grad.len(), 32768 * 1024);
    }

    #[test]
    fn test_accumulator_compute_block_sizes_350m() {
        // 350M: H=1024, num_heads=16, num_kv_heads=4, kv_dim=256, I=4096
        let sizes = PerBlockGradientAccumulator::compute_block_sizes(1024, 256, 4096);
        assert_eq!(sizes[component::W_Q], 1024 * 1024); // 1M
        assert_eq!(sizes[component::W_K], 1024 * 256); // 256K
        assert_eq!(sizes[component::W_V], 1024 * 256); // 256K
        assert_eq!(sizes[component::W_O], 1024 * 1024); // 1M
        assert_eq!(sizes[component::GATE], 1024 * 4096); // 4M
        assert_eq!(sizes[component::UP], 1024 * 4096); // 4M
        assert_eq!(sizes[component::DOWN], 4096 * 1024); // 4M
        assert_eq!(sizes[component::INPUT_NORM], 1024);
        assert_eq!(sizes[component::POST_ATTN_NORM], 1024);
    }

    #[test]
    fn test_accumulator_zero_all() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut acc = PerBlockGradientAccumulator::new(2, sizes, 10, 2);
        acc.block_grads[0].components[0] = vec![1.0, 2.0];
        acc.lm_head_grad[0] = 5.0;
        acc.accumulated_count = 3;
        acc.zero_all();
        assert!(acc.block_grads[0].components[0].iter().all(|&x| x == 0.0));
        assert_eq!(acc.lm_head_grad[0], 0.0);
        assert_eq!(acc.accumulated_count, 0);
    }

    #[test]
    fn test_accumulator_average() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut acc = PerBlockGradientAccumulator::new(2, sizes, 10, 2);
        acc.block_grads[1].components[0] = vec![4.0, 8.0];
        acc.lm_head_grad[0] = 2.0;
        acc.final_norm_grad[1] = 12.0;
        acc.embedding_grad[3] = -4.0;
        acc.accumulated_count = 4;
        acc.average();
        // 1/4 is exactly representable, so exact comparison is safe here.
        assert_eq!(acc.block_grads[1].components[0], vec![1.0, 2.0]);
        assert_eq!(acc.lm_head_grad[0], 0.5);
        assert_eq!(acc.final_norm_grad[1], 3.0);
        assert_eq!(acc.embedding_grad[3], -1.0);
    }

    #[test]
    fn test_accumulator_average_noop_for_single_batch() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut acc = PerBlockGradientAccumulator::new(1, sizes, 10, 2);
        acc.lm_head_grad[0] = 7.0;
        acc.average(); // accumulated_count == 0
        assert_eq!(acc.lm_head_grad[0], 7.0);
        acc.accumulated_count = 1;
        acc.average();
        assert_eq!(acc.lm_head_grad[0], 7.0);
    }

    #[test]
    fn test_accumulator_has_non_finite() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut acc = PerBlockGradientAccumulator::new(2, sizes, 10, 2);
        assert!(!acc.has_non_finite());
        acc.lm_head_grad[0] = f32::INFINITY;
        assert!(acc.has_non_finite());
    }
}