//! ZeRO (Zero Redundancy Optimizer) style sharding of optimizer states,
//! gradients, and parameters across data-parallel ranks.

use ghostflow_core::Tensor;
use std::collections::HashMap;

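/// ZeRO partitioning stage, controlling which training state is sharded
/// across data-parallel ranks.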
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ZeRoStage {
    /// Partition optimizer states only.
    Stage1,
    /// Partition optimizer states and gradients.
    Stage2,
    /// Partition optimizer states, gradients, and parameters.
    Stage3,
}

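/// Configuration for the ZeRO optimizer.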
#[derive(Debug, Clone)]
pub struct ZeRoConfig {
    /// Which ZeRO stage to apply.
    pub stage: ZeRoStage,
    /// Number of data-parallel ranks.
    pub world_size: usize,
    /// Rank of this process.
    pub rank: usize,
    /// Offload partitioned state to CPU memory.
    pub cpu_offload: bool,
    /// Offload partitioned state to NVMe storage.
    pub nvme_offload: bool,
    /// Overlap communication with computation.
    pub overlap_comm: bool,
    /// Gradient bucket size in elements.
    pub bucket_size: usize,
}

impl Default for ZeRoConfig {
    fn default() -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage2,
            world_size: 1,
            rank: 0,
            cpu_offload: false,
            nvme_offload: false,
            overlap_comm: true,
            bucket_size: 25_000_000,
        }
    }
}

impl ZeRoConfig {
    /// Stage 1 configuration: shard optimizer states only.
    pub fn stage1(world_size: usize, rank: usize) -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage1,
            world_size,
            rank,
            ..Default::default()
        }
    }

    /// Stage 2 configuration: shard optimizer states and gradients.
    pub fn stage2(world_size: usize, rank: usize) -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage2,
            world_size,
            rank,
            ..Default::default()
        }
    }

    /// Stage 3 configuration: shard optimizer states, gradients, and parameters.
    pub fn stage3(world_size: usize, rank: usize) -> Self {
        ZeRoConfig {
            stage: ZeRoStage::Stage3,
            world_size,
            rank,
            ..Default::default()
        }
    }

    /// Enable CPU and/or NVMe offload.
    pub fn with_offload(mut self, cpu: bool, nvme: bool) -> Self {
        self.cpu_offload = cpu;
        self.nvme_offload = nvme;
        self
    }
}

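/// Describes which rank owns a parameter and where it falls in the
/// flattened global parameter index space.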
#[derive(Debug, Clone)]
pub struct ParameterPartition {
    /// Parameter name.
    pub name: String,
    /// Rank that owns this partition.
    pub owner_rank: usize,
    /// Start offset in the flattened parameter space.
    pub start_idx: usize,
    /// End offset (exclusive) in the flattened parameter space.
    pub end_idx: usize,
    /// Original tensor shape.
    pub shape: Vec<usize>,
}

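/// ZeRO optimizer that shards training state across ranks according to the
/// configured stage, with optional CPU offload.
///
/// A minimal usage sketch (assuming a `params: HashMap<String, Tensor>` map
/// built by the caller):
///
/// ```ignore
/// let config = ZeRoConfig::stage2(4, 0);
/// let mut optimizer = ZeRoOptimizer::new(config, 1e-3);
/// optimizer.partition_parameters(&params)?;
/// optimizer.step()?;
/// ```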
pub struct ZeRoOptimizer {
    config: ZeRoConfig,
    /// Parameters owned by this rank.
    partitioned_params: HashMap<String, Tensor>,
    /// Gradients held by this rank.
    partitioned_grads: HashMap<String, Tensor>,
    /// Per-parameter optimizer states (momentum, variance) owned by this rank.
    partitioned_states: HashMap<String, HashMap<String, Tensor>>,
    /// Partition metadata for every parameter, replicated on every rank.
    param_partitions: Vec<ParameterPartition>,
    /// Gradient names grouped into communication buckets.
    gradient_buckets: Vec<Vec<String>>,
    /// Reusable communication buffer.
    comm_buffer: Vec<f32>,
    /// CPU-side staging buffers for offloaded parameters.
    cpu_buffer: HashMap<String, Vec<f32>>,
    learning_rate: f32,
}

impl ZeRoOptimizer {
    /// Creates an optimizer for the given configuration and learning rate.
    pub fn new(config: ZeRoConfig, learning_rate: f32) -> Self {
        ZeRoOptimizer {
            config,
            partitioned_params: HashMap::new(),
            partitioned_grads: HashMap::new(),
            partitioned_states: HashMap::new(),
            param_partitions: Vec::new(),
            gradient_buckets: Vec::new(),
            comm_buffer: Vec::new(),
            cpu_buffer: HashMap::new(),
            learning_rate,
        }
    }

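    /// Assigns each parameter to an owner rank by evenly splitting the
    /// flattened parameter space, and allocates optimizer states for the
    /// parameters this rank owns.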
    pub fn partition_parameters(&mut self, params: &HashMap<String, Tensor>) -> Result<(), String> {
        let total_params: usize = params.values()
            .map(|t| t.data_f32().len())
            .sum();

        // Ceiling division: number of elements each rank is responsible for.
        let params_per_rank = (total_params + self.config.world_size - 1) / self.config.world_size;

        // Iterate in sorted name order so every rank computes an identical
        // partitioning (HashMap iteration order is nondeterministic).
        let mut names: Vec<&String> = params.keys().collect();
        names.sort();

        let mut current_idx = 0;

        for name in names {
            let tensor = &params[name];
            let param_size = tensor.data_f32().len();
            let start_idx = current_idx;
            let end_idx = current_idx + param_size;

            // A parameter is owned entirely by the rank where it starts.
            let owner_rank = current_idx / params_per_rank;

            let partition = ParameterPartition {
                name: name.clone(),
                owner_rank,
                start_idx,
                end_idx,
                shape: tensor.dims().to_vec(),
            };

            self.param_partitions.push(partition);

            if owner_rank == self.config.rank {
                self.partitioned_params.insert(name.clone(), tensor.clone());

                // Allocate zero-initialized Adam states for owned parameters.
                let dims = tensor.dims();
                let zeros_data = vec![0.0f32; param_size];

                let mut states = HashMap::new();
                states.insert("momentum".to_string(), Tensor::from_slice(&zeros_data, dims).unwrap());
                states.insert("variance".to_string(), Tensor::from_slice(&zeros_data, dims).unwrap());
                self.partitioned_states.insert(name.clone(), states);
            }

            current_idx = end_idx;
        }

        Ok(())
    }

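    /// Retains the gradients this rank is responsible for. Stage 1 keeps all
    /// gradients (only optimizer states are sharded); stages 2 and 3 keep
    /// only the gradients of owned parameters.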
    pub fn partition_gradients(&mut self, grads: &HashMap<String, Tensor>) -> Result<(), String> {
        // Stage 1 does not shard gradients: keep the full set.
        if self.config.stage == ZeRoStage::Stage1 {
            for (name, grad) in grads {
                self.partitioned_grads.insert(name.clone(), grad.clone());
            }
            return Ok(());
        }

        // Stages 2 and 3: keep only gradients for parameters this rank owns.
        for partition in &self.param_partitions {
            if partition.owner_rank == self.config.rank {
                if let Some(grad) = grads.get(&partition.name) {
                    self.partitioned_grads.insert(partition.name.clone(), grad.clone());
                }
            }
        }

        Ok(())
    }

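    /// Averages the held gradients over the world size. In a distributed
    /// build this is where a reduce-scatter collective would run; this
    /// single-process implementation performs only the local averaging.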
    pub fn reduce_scatter_gradients(&mut self) -> Result<(), String> {
        let world_size = self.config.world_size as f32;

        for grad in self.partitioned_grads.values_mut() {
            // Average in place of the reduce step of reduce-scatter.
            let averaged: Vec<f32> = grad.data_f32().iter()
                .map(|&x| x / world_size)
                .collect();

            let dims = grad.dims().to_vec();
            *grad = Tensor::from_slice(&averaged, &dims)
                .map_err(|e| format!("Failed to create averaged gradient: {:?}", e))?;
        }

        Ok(())
    }

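    /// Returns the full parameter set. Under stage 3 a distributed build
    /// would all-gather shards from every rank; this implementation returns
    /// the locally held partitions.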
    pub fn all_gather_parameters(&self) -> Result<HashMap<String, Tensor>, String> {
        // Parameters are only sharded under stage 3; otherwise return the
        // locally held set as-is.
        if self.config.stage != ZeRoStage::Stage3 {
            return Ok(self.partitioned_params.clone());
        }

        // Stage 3: stand-in for the all-gather collective. A real
        // implementation would also gather the shards owned by other ranks.
        let mut all_params = HashMap::new();

        for (name, param) in &self.partitioned_params {
            all_params.insert(name.clone(), param.clone());
        }

        Ok(all_params)
    }

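    /// Applies an Adam-style update (without bias correction) to the
    /// parameters owned by this rank.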
    pub fn step(&mut self) -> Result<(), String> {
        for (name, param) in &mut self.partitioned_params {
            if let Some(grad) = self.partitioned_grads.get(name) {
                if let Some(states) = self.partitioned_states.get_mut(name) {
                    // Adam hyperparameters.
                    let beta1 = 0.9;
                    let beta2 = 0.999;
                    let eps = 1e-8;

                    let m_data = states.get("momentum").unwrap().data_f32();
                    let v_data = states.get("variance").unwrap().data_f32();
                    let g_data = grad.data_f32();
                    let p_data = param.data_f32();

                    let mut new_m = Vec::with_capacity(m_data.len());
                    let mut new_v = Vec::with_capacity(v_data.len());
                    let mut new_p = Vec::with_capacity(p_data.len());

                    // Element-wise Adam update:
                    //   m = beta1 * m + (1 - beta1) * g
                    //   v = beta2 * v + (1 - beta2) * g^2
                    //   p = p - lr * m / (sqrt(v) + eps)
                    for i in 0..m_data.len() {
                        let m = beta1 * m_data[i] + (1.0 - beta1) * g_data[i];
                        let v = beta2 * v_data[i] + (1.0 - beta2) * g_data[i] * g_data[i];
                        let p = p_data[i] - self.learning_rate * m / (v.sqrt() + eps);

                        new_m.push(m);
                        new_v.push(v);
                        new_p.push(p);
                    }

                    let m_dims = states.get("momentum").unwrap().dims().to_vec();
                    let v_dims = states.get("variance").unwrap().dims().to_vec();
                    let p_dims = param.dims().to_vec();

                    states.insert("momentum".to_string(), Tensor::from_slice(&new_m, &m_dims)
                        .map_err(|e| format!("Failed to create momentum tensor: {:?}", e))?);
                    states.insert("variance".to_string(), Tensor::from_slice(&new_v, &v_dims)
                        .map_err(|e| format!("Failed to create variance tensor: {:?}", e))?);
                    *param = Tensor::from_slice(&new_p, &p_dims)
                        .map_err(|e| format!("Failed to create param tensor: {:?}", e))?;
                }
            }
        }

        Ok(())
    }

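    /// Copies a parameter's data into a CPU-side buffer when CPU offload is
    /// enabled. A full implementation would also release the device copy.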
    pub fn offload_to_cpu(&mut self, name: &str) -> Result<(), String> {
        if !self.config.cpu_offload {
            return Ok(());
        }

        if let Some(param) = self.partitioned_params.get(name) {
            let data = param.data_f32().to_vec();
            self.cpu_buffer.insert(name.to_string(), data);
        }

        Ok(())
    }

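    /// Restores a parameter from its CPU-side buffer using the recorded
    /// partition shape.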
    pub fn load_from_cpu(&mut self, name: &str) -> Result<(), String> {
        if !self.config.cpu_offload {
            return Ok(());
        }

        if let Some(data) = self.cpu_buffer.get(name) {
            if let Some(partition) = self.param_partitions.iter().find(|p| p.name == name) {
                let tensor = Tensor::from_slice(data, &partition.shape)
                    .map_err(|e| format!("Failed to load from CPU: {:?}", e))?;
                self.partitioned_params.insert(name.to_string(), tensor);
            }
        }

        Ok(())
    }

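    /// Estimated fraction of per-rank training-state memory saved relative
    /// to unsharded data parallelism.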
    pub fn memory_savings_ratio(&self) -> f32 {
        match self.config.stage {
            ZeRoStage::Stage1 => {
                // Optimizer states (assumed to be half of training-state
                // memory) are split across n ranks.
                let n = self.config.world_size as f32;
                (n - 1.0) / n * 0.5
            }
            ZeRoStage::Stage2 => {
                // Optimizer states plus gradients (assumed to be 75% of
                // memory) are split across n ranks.
                let n = self.config.world_size as f32;
                (n - 1.0) / n * 0.75
            }
            ZeRoStage::Stage3 => {
                // Everything is split: savings approach 100% as n grows.
                let n = self.config.world_size as f32;
                (n - 1.0) / n
            }
        }
    }

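    /// Summarizes the optimizer's current partitioning state.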
    pub fn get_stats(&self) -> ZeRoStats {
        let total_params: usize = self.partitioned_params.values()
            .map(|t| t.data_f32().len())
            .sum();

        let total_grads: usize = self.partitioned_grads.values()
            .map(|t| t.data_f32().len())
            .sum();

        ZeRoStats {
            stage: self.config.stage,
            world_size: self.config.world_size,
            rank: self.config.rank,
            num_partitioned_params: self.partitioned_params.len(),
            total_param_elements: total_params,
            total_grad_elements: total_grads,
            memory_savings: self.memory_savings_ratio(),
            cpu_offload_enabled: self.config.cpu_offload,
        }
    }
}

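/// Snapshot of ZeRO optimizer statistics for one rank.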
#[derive(Debug, Clone)]
pub struct ZeRoStats {
    pub stage: ZeRoStage,
    pub world_size: usize,
    pub rank: usize,
    pub num_partitioned_params: usize,
    pub total_param_elements: usize,
    pub total_grad_elements: usize,
    pub memory_savings: f32,
    pub cpu_offload_enabled: bool,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zero_config() {
        let config = ZeRoConfig::default();
        assert_eq!(config.stage, ZeRoStage::Stage2);
        assert_eq!(config.world_size, 1);

        let stage3 = ZeRoConfig::stage3(4, 0);
        assert_eq!(stage3.stage, ZeRoStage::Stage3);
        assert_eq!(stage3.world_size, 4);
    }

    #[test]
    fn test_zero_optimizer_creation() {
        let config = ZeRoConfig::stage2(4, 0);
        let optimizer = ZeRoOptimizer::new(config, 0.001);

        let stats = optimizer.get_stats();
        assert_eq!(stats.world_size, 4);
        assert_eq!(stats.rank, 0);
    }

    #[test]
    fn test_partition_parameters() {
        let config = ZeRoConfig::stage2(2, 0);
        let mut optimizer = ZeRoOptimizer::new(config, 0.001);

        let mut params = HashMap::new();
        params.insert("layer1".to_string(), Tensor::randn(&[10, 10]));
        params.insert("layer2".to_string(), Tensor::randn(&[20, 20]));

        optimizer.partition_parameters(&params).unwrap();
        assert!(!optimizer.param_partitions.is_empty());
    }

    #[test]
    fn test_memory_savings_ratio() {
        let config1 = ZeRoConfig::stage1(4, 0);
        let optimizer1 = ZeRoOptimizer::new(config1, 0.001);
        let savings1 = optimizer1.memory_savings_ratio();

        let config2 = ZeRoConfig::stage2(4, 0);
        let optimizer2 = ZeRoOptimizer::new(config2, 0.001);
        let savings2 = optimizer2.memory_savings_ratio();

        let config3 = ZeRoConfig::stage3(4, 0);
        let optimizer3 = ZeRoOptimizer::new(config3, 0.001);
        let savings3 = optimizer3.memory_savings_ratio();

        // Higher stages shard more state, so savings strictly increase.
        assert!(savings3 > savings2);
        assert!(savings2 > savings1);
    }

    #[test]
    fn test_offload_config() {
        let config = ZeRoConfig::stage3(4, 0)
            .with_offload(true, false);

        assert!(config.cpu_offload);
        assert!(!config.nvme_offload);
    }

    #[test]
    fn test_cpu_offload() {
        let config = ZeRoConfig::stage2(2, 0).with_offload(true, false);
        let mut optimizer = ZeRoOptimizer::new(config, 0.001);

        let mut params = HashMap::new();
        params.insert("layer1".to_string(), Tensor::randn(&[5, 5]));

        optimizer.partition_parameters(&params).unwrap();
        optimizer.offload_to_cpu("layer1").unwrap();

        assert!(optimizer.cpu_buffer.contains_key("layer1"));
    }
}