// entrenar/train/transformer_trainer/zero.rs

//! ZeRO-1 optimizer state sharding for distributed pretraining.
//!
//! Implements optimizer state partitioning (ZeRO Stage 1), where each worker
//! holds only 1/N of the optimizer states (Adam m and v vectors). After
//! gradient AllReduce, each worker runs the optimizer step only for its
//! assigned shard, then all-gathers the updated weights.
//!
//! # Memory Savings
//!
//! For a 350M-parameter model, AdamW stores m + v (2 × param_count f32 values,
//! i.e. 350M × 2 × 4 bytes ≈ 2.8 GB):
//! - Without ZeRO: each worker holds ~2.8 GB of optimizer state
//! - With ZeRO-1 (N=2): each worker holds ~1.4 GB
//! - With ZeRO-1 (N=4): each worker holds ~0.7 GB
//!
//! # Contract (C-ZERO-001)
//!
//! - After all-gather, all workers hold identical updated weights
//! - Each worker's optimizer shard produces the same result as the full optimizer
//! - Shards are contiguous and non-overlapping, covering all parameters

/// Optimizer state shard assignment for a single worker.
///
/// Each worker owns a contiguous range of parameter indices and holds
/// only the Adam m/v states for those parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizerShard {
    /// This worker's rank
    pub rank: usize,
    /// Total workers in the ZeRO group
    pub world_size: usize,
    /// Start index (inclusive) in the flattened parameter vector
    pub param_start: usize,
    /// End index (exclusive) in the flattened parameter vector
    pub param_end: usize,
    /// Total parameter count across all shards
    pub total_params: usize,
}

impl OptimizerShard {
    /// Compute the shard assignment for a given rank.
    ///
    /// Divides `total_params` into `world_size` contiguous shards. The first
    /// `total_params % world_size` ranks each receive one extra element, so
    /// shard sizes differ by at most one (e.g. 10 params over 3 ranks → 4, 3, 3).
    ///
    /// # Contract (C-ZERO-001)
    ///
    /// - Union of all shards == [0, total_params)
    /// - Intersection of any two shards == empty
    /// - Each shard size is floor(total_params / world_size) or one more
    ///
    /// # Panics
    ///
    /// Panics if `world_size == 0` or `rank >= world_size`. An out-of-range
    /// rank would otherwise silently yield a range that belongs to no valid
    /// shard (and may extend past `total_params`).
    pub fn for_rank(rank: usize, world_size: usize, total_params: usize) -> Self {
        assert!(world_size > 0, "world_size must be > 0");
        assert!(rank < world_size, "rank {rank} out of range for world_size {world_size}");

        let base = total_params / world_size;
        let remainder = total_params % world_size;

        // The first `remainder` ranks take one extra element each, so the
        // start of `rank` is `rank * base` plus one for every extra-sized rank
        // before it. This equals the piecewise form
        //   rank * (base + 1)                              if rank < remainder
        //   remainder * (base + 1) + (rank - remainder) * base   otherwise.
        let param_start = rank * base + rank.min(remainder);
        let param_end = param_start + base + usize::from(rank < remainder);

        Self { rank, world_size, param_start, param_end, total_params }
    }

    /// Number of parameters in this shard.
    pub fn shard_size(&self) -> usize {
        self.param_end - self.param_start
    }

    /// Returns `true` if `param_idx` lies in `[param_start, param_end)`.
    pub fn owns_param(&self, param_idx: usize) -> bool {
        (self.param_start..self.param_end).contains(&param_idx)
    }

    /// Nominal fraction of optimizer memory saved versus full replication.
    ///
    /// Computed as `1 - 1/world_size` (e.g. 0.5 for 2 workers, 0.75 for 4).
    /// This ignores the at-most-one-element imbalance between shards; use
    /// [`Self::shard_memory_bytes`] for this rank's exact footprint.
    pub fn memory_savings(&self) -> f64 {
        1.0 - (1.0 / self.world_size as f64)
    }

    /// Optimizer memory held by this shard, in bytes.
    ///
    /// AdamW stores m + v, i.e. `2 × shard_size × size_of::<f32>()`.
    pub fn shard_memory_bytes(&self) -> usize {
        self.shard_size() * 2 * std::mem::size_of::<f32>()
    }

    /// Optimizer memory without sharding, in bytes
    /// (`2 × total_params × size_of::<f32>()`).
    pub fn full_memory_bytes(&self) -> usize {
        self.total_params * 2 * std::mem::size_of::<f32>()
    }
}

/// Block-level optimizer shard map.
///
/// Maps transformer blocks to worker ranks. In ZeRO-1, each block's
/// optimizer state is owned by exactly one worker. The owner runs the
/// optimizer step for that block after gradient AllReduce, then broadcasts
/// updated weights.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZeroShardMap {
    /// Which worker owns each block's optimizer state.
    /// `block_owners[i]` = rank that owns block i.
    pub block_owners: Vec<usize>,
    /// Which worker owns the LM head optimizer state.
    pub lm_head_owner: usize,
    /// Which worker owns the final norm optimizer state.
    pub final_norm_owner: usize,
    /// Which worker owns the embedding optimizer state.
    pub embedding_owner: usize,
    /// World size
    pub world_size: usize,
}

impl ZeroShardMap {
    /// Create a shard map distributing blocks round-robin across workers.
    ///
    /// Block `i` is owned by rank `i % world_size`. Non-block components
    /// (LM head, final norm, embedding) are all assigned to rank 0.
    ///
    /// NOTE(review): with large vocabularies the embedding / LM head can be a
    /// sizable share of parameters, making rank 0 heavier than the block
    /// counts suggest — confirm this placement is intended.
    ///
    /// # Panics
    ///
    /// Panics if `world_size == 0`.
    pub fn round_robin(num_blocks: usize, world_size: usize) -> Self {
        assert!(world_size > 0, "world_size must be > 0");
        let block_owners = (0..num_blocks).map(|block_idx| block_idx % world_size).collect();

        Self { block_owners, lm_head_owner: 0, final_norm_owner: 0, embedding_owner: 0, world_size }
    }

    /// Create a shard map distributing blocks in contiguous chunks.
    ///
    /// Preferred for pipeline parallelism compatibility: worker 0 gets
    /// blocks `0..N/W`, worker 1 gets `N/W..2N/W`, etc. When `W` does not
    /// divide `N`, the first `N % W` workers each take one extra block (the
    /// same rule as [`OptimizerShard::for_rank`]); if `W > N`, the trailing
    /// workers own no blocks. Non-block components are assigned to rank 0.
    ///
    /// # Panics
    ///
    /// Panics if `world_size == 0`.
    pub fn contiguous(num_blocks: usize, world_size: usize) -> Self {
        assert!(world_size > 0, "world_size must be > 0");
        let blocks_per_worker = num_blocks / world_size;
        let remainder = num_blocks % world_size;
        let mut block_owners = Vec::with_capacity(num_blocks);

        for rank in 0..world_size {
            let count = blocks_per_worker + usize::from(rank < remainder);
            // Append `count` consecutive entries owned by `rank`.
            block_owners.resize(block_owners.len() + count, rank);
        }

        Self { block_owners, lm_head_owner: 0, final_norm_owner: 0, embedding_owner: 0, world_size }
    }

    /// Get the owning rank for a given block index.
    ///
    /// # Panics
    ///
    /// Panics if `block_idx >= self.block_owners.len()`.
    pub fn block_owner(&self, block_idx: usize) -> usize {
        self.block_owners[block_idx]
    }

    /// Check whether `rank` owns the given block's optimizer state.
    ///
    /// # Panics
    ///
    /// Panics if `block_idx >= self.block_owners.len()`.
    pub fn rank_owns_block(&self, rank: usize, block_idx: usize) -> bool {
        self.block_owner(block_idx) == rank
    }

    /// Get all block indices owned by a given rank, in ascending order.
    pub fn blocks_for_rank(&self, rank: usize) -> Vec<usize> {
        self.block_owners
            .iter()
            .enumerate()
            .filter(|&(_, &owner)| owner == rank)
            .map(|(block_idx, _)| block_idx)
            .collect()
    }

    /// Number of blocks owned by a given rank.
    pub fn num_blocks_for_rank(&self, rank: usize) -> usize {
        self.block_owners.iter().filter(|&&owner| owner == rank).count()
    }

    /// Fraction of transformer blocks whose optimizer state `rank` holds.
    ///
    /// This is used as a proxy for the rank's share of optimizer memory and
    /// is exact only if all blocks are the same size. Only `block_owners` is
    /// counted; the LM head, final norm, and embedding are not included.
    /// Returns `0.0` for a map with no blocks (rather than NaN from `0 / 0`).
    pub fn memory_fraction_for_rank(&self, rank: usize) -> f64 {
        let total = self.block_owners.len();
        if total == 0 {
            return 0.0;
        }
        self.num_blocks_for_rank(rank) as f64 / total as f64
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_shard_basic() {
        // 100 params over 4 workers: rank 0 owns exactly [0, 25).
        let first = OptimizerShard::for_rank(0, 4, 100);
        assert_eq!(first.shard_size(), 25);
        assert_eq!(first.param_start, 0);
        assert_eq!(first.param_end, 25);
        // Both edges of the range are owned; the index just past it is not.
        for inside in [0, 24] {
            assert!(first.owns_param(inside));
        }
        assert!(!first.owns_param(25));
    }

    #[test]
    fn test_optimizer_shard_remainder() {
        // 10 params over 3 workers: the single leftover goes to rank 0 → 4, 3, 3.
        let shards: Vec<OptimizerShard> =
            (0..3).map(|rank| OptimizerShard::for_rank(rank, 3, 10)).collect();

        let sizes: Vec<usize> = shards.iter().map(OptimizerShard::shard_size).collect();
        assert_eq!(sizes, vec![4, 3, 3]);

        // Ranges are back-to-back (non-overlapping) and end at the total (complete).
        let bounds: Vec<(usize, usize)> =
            shards.iter().map(|s| (s.param_start, s.param_end)).collect();
        assert_eq!(bounds, vec![(0, 4), (4, 7), (7, 10)]);
    }

    #[test]
    fn test_optimizer_shard_completeness() {
        // C-ZERO-001: every parameter is claimed by exactly one shard.
        let total = 1_000_003; // prime, so the split is uneven
        let world_size = 7;
        let mut covered = vec![false; total];
        for rank in 0..world_size {
            let shard = OptimizerShard::for_rank(rank, world_size, total);
            for i in shard.param_start..shard.param_end {
                let already = std::mem::replace(&mut covered[i], true);
                assert!(!already, "param {i} covered by multiple shards");
            }
        }
        assert!(!covered.contains(&false), "not all params covered");
    }

    #[test]
    fn test_optimizer_shard_memory_savings() {
        let shard = OptimizerShard::for_rank(0, 4, 1_000_000);
        // 4 workers → each saves 3/4 of the replicated state.
        assert!((shard.memory_savings() - 0.75).abs() < 1e-10);
        // m + v at 4 bytes each: 1M × 2 × 4 = 8 MB total, 250K × 2 × 4 = 2 MB per shard.
        let bytes_per_param = 2 * 4;
        assert_eq!(shard.full_memory_bytes(), 1_000_000 * bytes_per_param);
        assert_eq!(shard.shard_memory_bytes(), 250_000 * bytes_per_param);
    }

    #[test]
    fn test_zero_shard_map_round_robin() {
        let map = ZeroShardMap::round_robin(24, 4);
        // Owners cycle 0, 1, 2, 3 and wrap back to 0 at block 4.
        for (block, owner) in [(0, 0), (1, 1), (2, 2), (3, 3), (4, 0)] {
            assert_eq!(map.block_owner(block), owner);
        }

        for rank in [0, 1] {
            assert_eq!(map.num_blocks_for_rank(rank), 6);
        }

        assert_eq!(map.blocks_for_rank(0), vec![0, 4, 8, 12, 16, 20]);
    }

    #[test]
    fn test_zero_shard_map_contiguous() {
        let map = ZeroShardMap::contiguous(24, 4);
        // 24 blocks / 4 workers = 6 consecutive blocks per rank.
        for rank in 0..4 {
            let expected: Vec<usize> = (rank * 6..rank * 6 + 6).collect();
            assert_eq!(map.blocks_for_rank(rank), expected);
        }
    }

    #[test]
    fn test_zero_shard_map_contiguous_uneven() {
        let map = ZeroShardMap::contiguous(10, 3);
        // 10 = 3 × 3 + 1 → rank 0 takes the leftover block.
        let counts: Vec<usize> = (0..3).map(|rank| map.num_blocks_for_rank(rank)).collect();
        assert_eq!(counts, vec![4, 3, 3]);

        // All blocks covered
        assert_eq!(counts.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_zero_shard_map_memory_fraction() {
        let fraction = ZeroShardMap::round_robin(24, 4).memory_fraction_for_rank(0);
        assert!((fraction - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_zero_shard_map_rank_owns_block() {
        // 12 blocks / 3 ranks → rank 0 owns 0..4, rank 1 owns 4..8.
        let map = ZeroShardMap::contiguous(12, 3);
        let checks = [(0, 0, true), (0, 3, true), (0, 4, false), (1, 4, true)];
        for (rank, block, owned) in checks {
            assert_eq!(map.rank_owns_block(rank, block), owned);
        }
    }

    #[test]
    fn test_zero_shard_350m() {
        // 350M model: 24 blocks sharded across 4 GPUs.
        let map = ZeroShardMap::contiguous(24, 4);
        // Every GPU owns 6 blocks, i.e. 25% of the optimizer memory.
        (0..4).for_each(|rank| {
            assert_eq!(map.num_blocks_for_rank(rank), 6);
            assert!((map.memory_fraction_for_rank(rank) - 0.25).abs() < 1e-10);
        });
    }
}