// entrenar/train/transformer_trainer/grad_accumulator.rs
#![allow(dead_code)]
/// Number of gradient buffers held for each transformer block.
///
/// Every index in the `component` module addresses one of these buffers, so
/// valid component indices are `0..BLOCK_GRAD_COMPONENTS`.
pub const BLOCK_GRAD_COMPONENTS: usize = 9;

/// Indices of the per-block gradient components.
///
/// These index `BlockGradientSet::components` and the size array returned by
/// `PerBlockGradientAccumulator::compute_block_sizes`.
pub mod component {
    /// `W_Q` weight gradient (`hidden * hidden` elements per `compute_block_sizes`).
    pub const W_Q: usize = 0;
    /// `W_K` weight gradient (`hidden * kv_hidden` elements).
    pub const W_K: usize = 1;
    /// `W_V` weight gradient (`hidden * kv_hidden` elements).
    pub const W_V: usize = 2;
    /// `W_O` weight gradient (`hidden * hidden` elements).
    pub const W_O: usize = 3;
    /// Gate weight gradient (`hidden * intermediate` elements).
    pub const GATE: usize = 4;
    /// Up weight gradient (`hidden * intermediate` elements).
    pub const UP: usize = 5;
    /// Down weight gradient (`intermediate * hidden` elements).
    pub const DOWN: usize = 6;
    /// Input-norm weight gradient (`hidden` elements).
    pub const INPUT_NORM: usize = 7;
    /// Post-attention-norm weight gradient (`hidden` elements).
    pub const POST_ATTN_NORM: usize = 8;

    // Compile-time guard: the last index must be `BLOCK_GRAD_COMPONENTS - 1`.
    // Adding a component without bumping the count (or vice versa) then fails
    // to build instead of silently mis-sizing arrays.
    const _: () = assert!(POST_ATTN_NORM + 1 == super::BLOCK_GRAD_COMPONENTS);
}

/// Identifiers for the gradient buffers that live outside the transformer
/// blocks.
///
/// The names mirror the `lm_head_grad`, `final_norm_grad` and
/// `embedding_grad` fields of `PerBlockGradientAccumulator`.
pub mod non_block {
    // NOTE(review): these are `u8` while `component` indices are `usize`,
    // which suggests compact tags (e.g. for a wire/FFI protocol) rather than
    // array indices -- confirm against callers before changing the type.

    /// Tag for the LM-head gradient buffer.
    pub const LM_HEAD: u8 = 0;
    /// Tag for the final-norm gradient buffer.
    pub const FINAL_NORM: u8 = 1;
    /// Tag for the embedding gradient buffer.
    pub const EMBEDDING: u8 = 2;
}

/// Gradient buffers for a single transformer block.
///
/// `components[i]` holds the flat gradient for the component whose index is
/// `i` in the `component` module. There are normally `BLOCK_GRAD_COMPONENTS`
/// entries, but the field is public and that length is not enforced.
#[derive(Debug, Clone, PartialEq)]
pub struct BlockGradientSet {
    /// One flat `f32` buffer per component, in component-index order.
    pub components: Vec<Vec<f32>>,
}

53impl BlockGradientSet {
54 pub fn zeroed(sizes: &[usize; BLOCK_GRAD_COMPONENTS]) -> Self {
59 let components = sizes.iter().map(|&sz| vec![0.0f32; sz]).collect();
60 Self { components }
61 }
62
63 pub fn total_elements(&self) -> usize {
65 self.components.iter().map(Vec::len).sum()
66 }
67
68 pub fn component_sizes_u32(&self) -> Vec<u32> {
70 self.components.iter().map(|c| c.len() as u32).collect()
71 }
72
73 pub fn flatten(&self) -> Vec<f32> {
75 let total = self.total_elements();
76 let mut flat = Vec::with_capacity(total);
77 for comp in &self.components {
78 flat.extend_from_slice(comp);
79 }
80 flat
81 }
82
83 pub fn from_flat(flat: &[f32], sizes: &[u32]) -> Self {
88 let total: usize = sizes.iter().map(|&s| s as usize).sum();
89 assert_eq!(flat.len(), total, "flat gradient length mismatch");
90 let mut components = Vec::with_capacity(sizes.len());
91 let mut offset = 0;
92 for &sz in sizes {
93 let sz = sz as usize;
94 components.push(flat[offset..offset + sz].to_vec());
95 offset += sz;
96 }
97 Self { components }
98 }
99
100 pub fn zero(&mut self) {
102 for comp in &mut self.components {
103 for x in comp.iter_mut() {
104 *x = 0.0;
105 }
106 }
107 }
108
109 pub fn accumulate(&mut self, other: &BlockGradientSet) {
114 assert_eq!(self.components.len(), other.components.len());
115 for (dst, src) in self.components.iter_mut().zip(&other.components) {
116 assert_eq!(dst.len(), src.len(), "component size mismatch");
117 for (d, s) in dst.iter_mut().zip(src) {
118 *d += s;
119 }
120 }
121 }
122
123 pub fn scale(&mut self, divisor: f32) {
125 let inv = 1.0 / divisor;
126 for comp in &mut self.components {
127 for x in comp.iter_mut() {
128 *x *= inv;
129 }
130 }
131 }
132
133 pub fn has_non_finite(&self) -> bool {
135 self.components.iter().any(|comp| comp.iter().any(|x| !x.is_finite()))
136 }
137}
138
139#[derive(Debug)]
152pub struct PerBlockGradientAccumulator {
153 pub block_grads: Vec<BlockGradientSet>,
155 pub lm_head_grad: Vec<f32>,
157 pub final_norm_grad: Vec<f32>,
159 pub embedding_grad: Vec<f32>,
161 pub accumulated_count: usize,
163 pub block_component_sizes: [usize; BLOCK_GRAD_COMPONENTS],
165}
166
167impl PerBlockGradientAccumulator {
168 pub fn new(
176 num_blocks: usize,
177 block_sizes: [usize; BLOCK_GRAD_COMPONENTS],
178 vocab_size: usize,
179 hidden_size: usize,
180 ) -> Self {
181 let block_grads = (0..num_blocks).map(|_| BlockGradientSet::zeroed(&block_sizes)).collect();
182
183 Self {
184 block_grads,
185 lm_head_grad: vec![0.0; vocab_size * hidden_size],
186 final_norm_grad: vec![0.0; hidden_size],
187 embedding_grad: vec![0.0; vocab_size * hidden_size],
188 accumulated_count: 0,
189 block_component_sizes: block_sizes,
190 }
191 }
192
193 pub fn compute_block_sizes(
200 hidden_size: usize,
201 kv_hidden_size: usize,
202 intermediate_size: usize,
203 ) -> [usize; BLOCK_GRAD_COMPONENTS] {
204 [
205 hidden_size * hidden_size, hidden_size * kv_hidden_size, hidden_size * kv_hidden_size, hidden_size * hidden_size, hidden_size * intermediate_size, hidden_size * intermediate_size, intermediate_size * hidden_size, hidden_size, hidden_size, ]
215 }
216
217 pub fn zero_all(&mut self) {
219 for block_grad in &mut self.block_grads {
220 block_grad.zero();
221 }
222 self.lm_head_grad.iter_mut().for_each(|x| *x = 0.0);
223 self.final_norm_grad.iter_mut().for_each(|x| *x = 0.0);
224 self.embedding_grad.iter_mut().for_each(|x| *x = 0.0);
225 self.accumulated_count = 0;
226 }
227
228 pub fn average(&mut self) {
230 if self.accumulated_count <= 1 {
231 return;
232 }
233 let n = self.accumulated_count as f32;
234 for block_grad in &mut self.block_grads {
235 block_grad.scale(n);
236 }
237 let inv = 1.0 / n;
238 for x in &mut self.lm_head_grad {
239 *x *= inv;
240 }
241 for x in &mut self.final_norm_grad {
242 *x *= inv;
243 }
244 for x in &mut self.embedding_grad {
245 *x *= inv;
246 }
247 }
248
249 pub fn has_non_finite(&self) -> bool {
251 self.block_grads.iter().any(BlockGradientSet::has_non_finite)
252 || self.lm_head_grad.iter().any(|x| !x.is_finite())
253 || self.final_norm_grad.iter().any(|x| !x.is_finite())
254 || self.embedding_grad.iter().any(|x| !x.is_finite())
255 }
256
257 pub fn num_blocks(&self) -> usize {
259 self.block_grads.len()
260 }
261}
262
#[cfg(test)]
mod tests {
    //! Unit tests for `BlockGradientSet` and `PerBlockGradientAccumulator`.
    //! All buffers are kept small so the suite stays fast and light on memory.
    use super::*;

    #[test]
    fn test_block_gradient_set_zeroed() {
        let sizes = [100, 50, 50, 100, 200, 200, 200, 10, 10];
        let bg = BlockGradientSet::zeroed(&sizes);
        assert_eq!(bg.components.len(), 9);
        assert_eq!(bg.total_elements(), 920);
        assert!(bg.components[0].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_block_gradient_set_flatten_roundtrip() {
        let sizes = [4, 2, 2, 4, 8, 8, 8, 1, 1];
        let mut bg = BlockGradientSet::zeroed(&sizes);
        for (i, comp) in bg.components.iter_mut().enumerate() {
            // Encode (component, element) so a misplaced value is detectable.
            for (j, val) in comp.iter_mut().enumerate() {
                *val = (i * 100 + j) as f32;
            }
        }
        let flat = bg.flatten();
        assert_eq!(flat.len(), 38);

        let sizes_u32 = bg.component_sizes_u32();
        let reconstructed = BlockGradientSet::from_flat(&flat, &sizes_u32);
        for (orig, recon) in bg.components.iter().zip(&reconstructed.components) {
            assert_eq!(orig, recon);
        }
    }

    #[test]
    #[should_panic(expected = "flat gradient length mismatch")]
    fn test_block_gradient_set_from_flat_length_mismatch() {
        let _ = BlockGradientSet::from_flat(&[1.0, 2.0], &[3]);
    }

    #[test]
    fn test_block_gradient_set_accumulate() {
        let sizes = [2, 2, 2, 2, 2, 2, 2, 1, 1];
        let mut a = BlockGradientSet::zeroed(&sizes);
        let mut b = BlockGradientSet::zeroed(&sizes);
        a.components[0] = vec![1.0, 2.0];
        b.components[0] = vec![3.0, 4.0];
        a.accumulate(&b);
        assert_eq!(a.components[0], vec![4.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "component size mismatch")]
    fn test_block_gradient_set_accumulate_size_mismatch() {
        let mut a = BlockGradientSet::zeroed(&[2, 1, 1, 1, 1, 1, 1, 1, 1]);
        let b = BlockGradientSet::zeroed(&[3, 1, 1, 1, 1, 1, 1, 1, 1]);
        a.accumulate(&b);
    }

    #[test]
    fn test_block_gradient_set_scale() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut bg = BlockGradientSet::zeroed(&sizes);
        bg.components[0] = vec![6.0, 9.0];
        bg.scale(3.0);
        assert!((bg.components[0][0] - 2.0).abs() < 1e-6);
        assert!((bg.components[0][1] - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_block_gradient_set_has_non_finite() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut bg = BlockGradientSet::zeroed(&sizes);
        assert!(!bg.has_non_finite());
        bg.components[0][0] = f32::NAN;
        assert!(bg.has_non_finite());
    }

    #[test]
    fn test_accumulator_new() {
        // Small dimensions on purpose: the real 350M config (24 blocks,
        // 32768 vocab, 1024 hidden) allocates well over a gigabyte of
        // gradient buffers just to check lengths.
        let sizes = PerBlockGradientAccumulator::compute_block_sizes(16, 4, 64);
        let acc = PerBlockGradientAccumulator::new(3, sizes, 50, 16);
        assert_eq!(acc.num_blocks(), 3);
        assert_eq!(acc.lm_head_grad.len(), 50 * 16);
        assert_eq!(acc.final_norm_grad.len(), 16);
        assert_eq!(acc.embedding_grad.len(), 50 * 16);
        assert_eq!(acc.accumulated_count, 0);
        assert_eq!(acc.block_component_sizes, sizes);
        for block in &acc.block_grads {
            assert_eq!(block.total_elements(), sizes.iter().sum::<usize>());
        }
    }

    #[test]
    fn test_accumulator_compute_block_sizes_350m() {
        let sizes = PerBlockGradientAccumulator::compute_block_sizes(1024, 256, 4096);
        assert_eq!(sizes[component::W_Q], 1024 * 1024);
        assert_eq!(sizes[component::W_K], 1024 * 256);
        assert_eq!(sizes[component::W_V], 1024 * 256);
        assert_eq!(sizes[component::W_O], 1024 * 1024);
        assert_eq!(sizes[component::GATE], 1024 * 4096);
        assert_eq!(sizes[component::UP], 1024 * 4096);
        assert_eq!(sizes[component::DOWN], 4096 * 1024);
        assert_eq!(sizes[component::INPUT_NORM], 1024);
        assert_eq!(sizes[component::POST_ATTN_NORM], 1024);
    }

    #[test]
    fn test_accumulator_zero_all() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut acc = PerBlockGradientAccumulator::new(2, sizes, 10, 2);
        acc.block_grads[0].components[0] = vec![1.0, 2.0];
        acc.lm_head_grad[0] = 5.0;
        acc.accumulated_count = 3;
        acc.zero_all();
        assert!(acc.block_grads[0].components[0].iter().all(|&x| x == 0.0));
        assert_eq!(acc.lm_head_grad[0], 0.0);
        assert_eq!(acc.accumulated_count, 0);
    }

    #[test]
    fn test_accumulator_average() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut acc = PerBlockGradientAccumulator::new(2, sizes, 2, 2);
        acc.block_grads[1].components[0] = vec![4.0, 8.0];
        acc.lm_head_grad[0] = 6.0;
        acc.final_norm_grad[0] = 2.0;
        acc.embedding_grad[3] = 10.0;

        // A count of 0 or 1 leaves the buffers untouched.
        acc.accumulated_count = 1;
        acc.average();
        assert_eq!(acc.lm_head_grad[0], 6.0);

        acc.accumulated_count = 2;
        acc.average();
        assert_eq!(acc.block_grads[1].components[0], vec![2.0, 4.0]);
        assert_eq!(acc.lm_head_grad[0], 3.0);
        assert_eq!(acc.final_norm_grad[0], 1.0);
        assert_eq!(acc.embedding_grad[3], 5.0);
    }

    #[test]
    fn test_accumulator_has_non_finite() {
        let sizes = [2, 1, 1, 1, 1, 1, 1, 1, 1];
        let mut acc = PerBlockGradientAccumulator::new(2, sizes, 10, 2);
        assert!(!acc.has_non_finite());
        acc.lm_head_grad[0] = f32::INFINITY;
        assert!(acc.has_non_finite());

        acc.zero_all();
        acc.block_grads[1].components[5][0] = f32::NAN;
        assert!(acc.has_non_finite());
    }
}