// entrenar/train/transformer_trainer/zero.rs
//
// Sharding helpers for ZeRO-style distributed training: `OptimizerShard` splits a flat
// parameter index space into per-rank ranges, and `ZeroShardMap` assigns whole blocks
// to ranks.

/// The contiguous slice of the flat parameter space whose optimizer state one rank owns.
///
/// The range is half-open: parameters `param_start..param_end` belong to this shard.
/// Build it with [`OptimizerShard::for_rank`]. The shards of ranks `0..world_size` are
/// then disjoint and together cover `0..total_params`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OptimizerShard {
    /// Rank (worker index) that owns this shard.
    pub rank: usize,
    /// Total number of ranks the parameters are split across.
    pub world_size: usize,
    /// Index of the first parameter in this shard (inclusive).
    pub param_start: usize,
    /// Index one past the last parameter in this shard (exclusive).
    pub param_end: usize,
    /// Number of parameters across all shards.
    pub total_params: usize,
}

39impl OptimizerShard {
40 pub fn for_rank(rank: usize, world_size: usize, total_params: usize) -> Self {
51 let shard_size = total_params / world_size;
52 let remainder = total_params % world_size;
53
54 let param_start = if rank < remainder {
56 rank * (shard_size + 1)
57 } else {
58 remainder * (shard_size + 1) + (rank - remainder) * shard_size
59 };
60
61 let param_end =
62 if rank < remainder { param_start + shard_size + 1 } else { param_start + shard_size };
63
64 Self { rank, world_size, param_start, param_end, total_params }
65 }
66
67 pub fn shard_size(&self) -> usize {
69 self.param_end - self.param_start
70 }
71
72 pub fn owns_param(&self, param_idx: usize) -> bool {
74 param_idx >= self.param_start && param_idx < self.param_end
75 }
76
77 pub fn memory_savings(&self) -> f64 {
81 1.0 - (1.0 / self.world_size as f64)
82 }
83
84 pub fn shard_memory_bytes(&self) -> usize {
88 self.shard_size() * 2 * std::mem::size_of::<f32>()
89 }
90
91 pub fn full_memory_bytes(&self) -> usize {
93 self.total_params * 2 * std::mem::size_of::<f32>()
94 }
95}
96
/// Assignment of whole model blocks, and of the non-block parameter groups, to ranks.
///
/// [`OptimizerShard`] splits a flat parameter range. This map works at block
/// granularity instead: each block index maps to the single rank that owns it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ZeroShardMap {
    /// Owning rank of each block, indexed by block index.
    pub block_owners: Vec<usize>,
    /// Rank that owns the LM-head parameters.
    pub lm_head_owner: usize,
    /// Rank that owns the final normalization parameters.
    pub final_norm_owner: usize,
    /// Rank that owns the embedding parameters.
    pub embedding_owner: usize,
    /// Total number of ranks.
    pub world_size: usize,
}

118impl ZeroShardMap {
119 pub fn round_robin(num_blocks: usize, world_size: usize) -> Self {
124 let block_owners: Vec<usize> = (0..num_blocks).map(|i| i % world_size).collect();
125
126 Self { block_owners, lm_head_owner: 0, final_norm_owner: 0, embedding_owner: 0, world_size }
127 }
128
129 pub fn contiguous(num_blocks: usize, world_size: usize) -> Self {
134 let blocks_per_worker = num_blocks / world_size;
135 let remainder = num_blocks % world_size;
136 let mut block_owners = Vec::with_capacity(num_blocks);
137
138 for rank in 0..world_size {
139 let count = blocks_per_worker + usize::from(rank < remainder);
140 for _ in 0..count {
141 block_owners.push(rank);
142 }
143 }
144
145 Self { block_owners, lm_head_owner: 0, final_norm_owner: 0, embedding_owner: 0, world_size }
146 }
147
148 pub fn block_owner(&self, block_idx: usize) -> usize {
150 self.block_owners[block_idx]
151 }
152
153 pub fn rank_owns_block(&self, rank: usize, block_idx: usize) -> bool {
155 self.block_owners[block_idx] == rank
156 }
157
158 pub fn blocks_for_rank(&self, rank: usize) -> Vec<usize> {
160 self.block_owners
161 .iter()
162 .enumerate()
163 .filter(|(_, &owner)| owner == rank)
164 .map(|(i, _)| i)
165 .collect()
166 }
167
168 pub fn num_blocks_for_rank(&self, rank: usize) -> usize {
170 self.block_owners.iter().filter(|&&owner| owner == rank).count()
171 }
172
173 pub fn memory_fraction_for_rank(&self, rank: usize) -> f64 {
177 let owned = self.num_blocks_for_rank(rank) as f64;
178 let total = self.block_owners.len() as f64;
179 owned / total
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_shard_basic() {
        let shard = OptimizerShard::for_rank(0, 4, 100);
        assert_eq!(shard.shard_size(), 25);
        assert_eq!(shard.param_start, 0);
        assert_eq!(shard.param_end, 25);
        assert!(shard.owns_param(0));
        assert!(shard.owns_param(24));
        assert!(!shard.owns_param(25));
    }

    #[test]
    fn test_optimizer_shard_remainder() {
        // 10 params over 3 ranks: the single leftover param goes to rank 0.
        let s0 = OptimizerShard::for_rank(0, 3, 10);
        let s1 = OptimizerShard::for_rank(1, 3, 10);
        let s2 = OptimizerShard::for_rank(2, 3, 10);

        assert_eq!(s0.shard_size(), 4);
        assert_eq!(s1.shard_size(), 3);
        assert_eq!(s2.shard_size(), 3);

        assert_eq!(s0.param_start, 0);
        assert_eq!(s0.param_end, 4);
        assert_eq!(s1.param_start, 4);
        assert_eq!(s1.param_end, 7);
        assert_eq!(s2.param_start, 7);
        assert_eq!(s2.param_end, 10);
    }

    #[test]
    fn test_optimizer_shard_more_ranks_than_params() {
        // 2 params over 4 ranks: ranks 0 and 1 get one each; ranks 2 and 3 are empty.
        let sizes: Vec<usize> =
            (0..4).map(|rank| OptimizerShard::for_rank(rank, 4, 2).shard_size()).collect();
        assert_eq!(sizes, vec![1, 1, 0, 0]);

        let empty = OptimizerShard::for_rank(3, 4, 2);
        assert_eq!(empty.param_start, 2);
        assert_eq!(empty.param_end, 2);
        assert!(!empty.owns_param(2));
    }

    #[test]
    fn test_optimizer_shard_completeness() {
        // A prime total guarantees a non-zero remainder.
        let total = 1_000_003;
        let world_size = 7;
        let mut covered = vec![false; total];
        for rank in 0..world_size {
            let shard = OptimizerShard::for_rank(rank, world_size, total);
            for i in shard.param_start..shard.param_end {
                assert!(!covered[i], "param {i} covered by multiple shards");
                covered[i] = true;
            }
        }
        assert!(covered.iter().all(|&c| c), "not all params covered");
    }

    #[test]
    #[should_panic]
    fn test_optimizer_shard_zero_world_size_panics() {
        let _ = OptimizerShard::for_rank(0, 0, 10);
    }

    #[test]
    fn test_optimizer_shard_memory_savings() {
        let shard = OptimizerShard::for_rank(0, 4, 1_000_000);
        assert!((shard.memory_savings() - 0.75).abs() < 1e-10);
        // 1M params x 2 f32 state values x 4 bytes.
        assert_eq!(shard.full_memory_bytes(), 8_000_000);
        // One quarter of the full state.
        assert_eq!(shard.shard_memory_bytes(), 2_000_000);
    }

    #[test]
    fn test_zero_shard_map_round_robin() {
        let map = ZeroShardMap::round_robin(24, 4);
        assert_eq!(map.block_owner(0), 0);
        assert_eq!(map.block_owner(1), 1);
        assert_eq!(map.block_owner(2), 2);
        assert_eq!(map.block_owner(3), 3);
        assert_eq!(map.block_owner(4), 0);

        assert_eq!(map.num_blocks_for_rank(0), 6);
        assert_eq!(map.num_blocks_for_rank(1), 6);

        let blocks = map.blocks_for_rank(0);
        assert_eq!(blocks, vec![0, 4, 8, 12, 16, 20]);
    }

    #[test]
    fn test_zero_shard_map_round_robin_uneven() {
        let map = ZeroShardMap::round_robin(10, 4);
        assert_eq!(map.blocks_for_rank(0), vec![0, 4, 8]);
        assert_eq!(map.blocks_for_rank(3), vec![3, 7]);
        let counts: Vec<usize> = (0..4).map(|rank| map.num_blocks_for_rank(rank)).collect();
        assert_eq!(counts, vec![3, 3, 2, 2]);
    }

    #[test]
    #[should_panic]
    fn test_zero_shard_map_round_robin_zero_world_size_panics() {
        let _ = ZeroShardMap::round_robin(4, 0);
    }

    #[test]
    fn test_zero_shard_map_contiguous() {
        let map = ZeroShardMap::contiguous(24, 4);
        assert_eq!(map.blocks_for_rank(0), vec![0, 1, 2, 3, 4, 5]);
        assert_eq!(map.blocks_for_rank(1), vec![6, 7, 8, 9, 10, 11]);
        assert_eq!(map.blocks_for_rank(2), vec![12, 13, 14, 15, 16, 17]);
        assert_eq!(map.blocks_for_rank(3), vec![18, 19, 20, 21, 22, 23]);
    }

    #[test]
    fn test_zero_shard_map_contiguous_uneven() {
        // 10 blocks over 3 ranks: rank 0 absorbs the leftover block.
        let map = ZeroShardMap::contiguous(10, 3);
        assert_eq!(map.num_blocks_for_rank(0), 4);
        assert_eq!(map.num_blocks_for_rank(1), 3);
        assert_eq!(map.num_blocks_for_rank(2), 3);
        assert_eq!(map.blocks_for_rank(0), vec![0, 1, 2, 3]);

        let total: usize = (0..3).map(|rank| map.num_blocks_for_rank(rank)).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_zero_shard_map_contiguous_more_ranks_than_blocks() {
        let map = ZeroShardMap::contiguous(2, 4);
        assert_eq!(map.block_owners, vec![0, 1]);
        assert!(map.blocks_for_rank(2).is_empty());
        assert!(map.blocks_for_rank(3).is_empty());
    }

    #[test]
    #[should_panic]
    fn test_zero_shard_map_contiguous_zero_world_size_panics() {
        let _ = ZeroShardMap::contiguous(4, 0);
    }

    #[test]
    fn test_zero_shard_map_memory_fraction() {
        let map = ZeroShardMap::round_robin(24, 4);
        let frac = map.memory_fraction_for_rank(0);
        assert!((frac - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_zero_shard_map_rank_owns_block() {
        let map = ZeroShardMap::contiguous(12, 3);
        assert!(map.rank_owns_block(0, 0));
        assert!(map.rank_owns_block(0, 3));
        assert!(!map.rank_owns_block(0, 4));
        assert!(map.rank_owns_block(1, 4));
    }

    #[test]
    fn test_zero_shard_350m() {
        // 24 blocks over 4 ranks (presumably the 350M model configuration).
        let map = ZeroShardMap::contiguous(24, 4);
        for rank in 0..4 {
            assert_eq!(map.num_blocks_for_rank(rank), 6);
            let frac = map.memory_fraction_for_rank(rank);
            assert!((frac - 0.25).abs() < 1e-10);
        }
    }
}