/// Aggregated state for a distributed run: the parallel configuration plus
/// the optional tensor-parallel and pipeline-parallel state built from it.
#[derive(Debug)]
pub struct DistributedContext {
/// Parallelism configuration supplied to `new`.
config: ParallelConfig,
/// Tensor-parallel state; `new` populates it only when `config.tp_size > 1`.
tensor_parallel: Option<TensorParallel>,
/// Pipeline-parallel state; `None` after `new`, populated by `init_pipeline`
/// when `config.pp_size > 1`.
pipeline_parallel: Option<PipelineParallel>,
/// Offload settings; starts as `ZeroOffload::default()` and can be replaced
/// through `set_zero_offload`.
zero_offload: ZeroOffload,
/// Set to `true` by `new`; reported by `is_initialized`.
initialized: bool,
}
impl DistributedContext {
    /// Builds a context from `config`.
    ///
    /// Tensor-parallel state is constructed only when `config.tp_size > 1`;
    /// otherwise it stays `None`. Pipeline-parallel state always starts as
    /// `None` (it is set up later by [`Self::init_pipeline`]), the offload
    /// settings start at `ZeroOffload::default()`, and the context is marked
    /// as initialized.
    ///
    /// # Errors
    ///
    /// Propagates the error from `TensorParallel::new` if it fails.
    pub fn new(config: ParallelConfig) -> Result<Self, ParallelError> {
        // The closure only runs when tensor parallelism is requested, so
        // `tp_rank()` is never evaluated for a `tp_size` of 1 or less.
        let tensor_parallel = (config.tp_size > 1)
            .then(|| TensorParallel::new(config.tp_size, config.tp_rank()))
            .transpose()?;

        Ok(Self {
            config,
            tensor_parallel,
            pipeline_parallel: None,
            zero_offload: ZeroOffload::default(),
            initialized: true,
        })
    }

    /// Creates the pipeline-parallel state for this context.
    ///
    /// When `config.pp_size` is 1 or less this is a no-op that returns
    /// `Ok(())` and leaves any existing pipeline state untouched. Otherwise a
    /// new `PipelineParallel` is built from the configured size and stage
    /// together with `total_layers` and `micro_batch_size`, replacing whatever
    /// was stored before.
    ///
    /// # Errors
    ///
    /// Propagates the error from `PipelineParallel::new`; in that case the
    /// previously stored pipeline state is left unchanged.
    pub fn init_pipeline(
        &mut self,
        total_layers: usize,
        micro_batch_size: usize,
    ) -> Result<(), ParallelError> {
        if self.config.pp_size <= 1 {
            return Ok(());
        }

        // Build first, store second: a failed construction must not clobber
        // the current value.
        let pipeline = PipelineParallel::new(
            self.config.pp_size,
            self.config.pp_stage(),
            total_layers,
            micro_batch_size,
        )?;
        self.pipeline_parallel = Some(pipeline);
        Ok(())
    }

    /// Replaces the stored offload settings with `zero`.
    pub fn set_zero_offload(&mut self, zero: ZeroOffload) {
        self.zero_offload = zero;
    }

    /// Borrows the parallel configuration this context was created with.
    pub fn config(&self) -> &ParallelConfig {
        &self.config
    }

    /// Borrows the tensor-parallel state, or `None` if it was not created.
    pub fn tensor_parallel(&self) -> Option<&TensorParallel> {
        self.tensor_parallel.as_ref()
    }

    /// Borrows the pipeline-parallel state, or `None` if it has not been set.
    pub fn pipeline_parallel(&self) -> Option<&PipelineParallel> {
        self.pipeline_parallel.as_ref()
    }

    /// Mutably borrows the pipeline-parallel state, or `None` if it has not
    /// been set.
    pub fn pipeline_parallel_mut(&mut self) -> Option<&mut PipelineParallel> {
        self.pipeline_parallel.as_mut()
    }

    /// Borrows the current offload settings.
    pub fn zero_offload(&self) -> &ZeroOffload {
        &self.zero_offload
    }

    /// Returns `true` when the configured `world_size` is greater than 1.
    pub fn is_distributed(&self) -> bool {
        self.config.world_size > 1
    }

    /// Returns the `initialized` flag, which `new` sets to `true`.
    pub fn is_initialized(&self) -> bool {
        self.initialized
    }
}
// Tests are declared as a separate `tests` module (defined outside this file)
// and are compiled only for test builds.
#[cfg(test)]
mod tests;