pub mod zero_optimizer;
pub mod zero_stage1;
pub mod zero_stage2;
pub mod zero_stage3;
pub mod zero_utils;
pub use zero_optimizer::{ZeROConfig, ZeROOptimizer, ZeROStage};
pub use zero_stage1::ZeROStage1;
pub use zero_stage2::ZeROStage2;
pub use zero_stage3::ZeROStage3;
pub use zero_utils::{
all_gather_gradients, gather_parameters, partition_gradients, partition_parameters,
reduce_scatter_gradients, GradientBuffer, ParameterGroup, ParameterPartition, ZeROState,
};
/// Identifies which ZeRO stage an implementation corresponds to.
///
/// The variants line up by name with the `ZeROStage1`, `ZeROStage2`, and
/// `ZeROStage3` types re-exported from this module. NOTE(review): the
/// mapping between variants and those types is implied by naming only —
/// confirm against the callers that construct them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ZeROImplementationStage {
    /// First ZeRO stage (see `ZeROStage1`).
    Stage1,
    /// Second ZeRO stage (see `ZeROStage2`).
    Stage2,
    /// Third ZeRO stage (see `ZeROStage3`).
    Stage3,
}
/// Counters describing how much memory a ZeRO configuration saves, plus the
/// communication it costs.
///
/// All counters are plain `usize` values. NOTE(review): units (bytes vs.
/// elements) are not established in this module — confirm with the code
/// that populates these fields.
#[derive(Debug, Clone, Default)]
pub struct ZeROMemoryStats {
    /// Memory saved on optimizer state.
    pub optimizer_memory_saved: usize,
    /// Memory saved on gradients.
    pub gradient_memory_saved: usize,
    /// Memory saved on parameters.
    pub parameter_memory_saved: usize,
    /// Sum of the three `*_memory_saved` counters; refreshed by
    /// [`ZeROMemoryStats::update_totals`], not kept in sync automatically.
    pub total_memory_saved: usize,
    /// Communication overhead; not included in `total_memory_saved`.
    pub communication_overhead: usize,
}

impl ZeROMemoryStats {
    /// Creates a stats record with every counter set to zero.
    ///
    /// Equivalent to [`Default::default`]; provided as a `const fn` so it
    /// can be used in constant contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            optimizer_memory_saved: 0,
            gradient_memory_saved: 0,
            parameter_memory_saved: 0,
            total_memory_saved: 0,
            communication_overhead: 0,
        }
    }

    /// Recomputes `total_memory_saved` as the sum of the optimizer, gradient,
    /// and parameter savings.
    ///
    /// The sum saturates at `usize::MAX` instead of overflowing. Plain `+`
    /// would panic in debug builds and silently wrap in release builds, and
    /// multi-GB counters can realistically exceed `usize::MAX` on 32-bit
    /// targets. `communication_overhead` is intentionally not included.
    pub fn update_totals(&mut self) {
        self.total_memory_saved = self
            .optimizer_memory_saved
            .saturating_add(self.gradient_memory_saved)
            .saturating_add(self.parameter_memory_saved);
    }
}