impl Communicator {
    pub fn new(world_size: usize, rank: usize) -> Result<Self, ParallelError> {
        if world_size == 0 {
            return Err(ParallelError::InvalidWorldSize(0));
        }
        if rank >= world_size {
            return Err(ParallelError::InvalidRank { rank, world_size });
        }
Ok(Self {
world_size,
rank,
})
}
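    /// Simulated all-reduce for a single-process setting: every rank is
    /// assumed to hold an identical copy of `tensor`, so a Sum reduction is
    /// just a scale by `world_size`, while Avg/Max/Min of identical copies
    /// leave the data unchanged.
    ///
    /// A minimal usage sketch (uses the `ParallelTensor { shape, data }`
    /// layout constructed elsewhere in this module):
    ///
    /// ```ignore
    /// let comm = Communicator::new(4, 0)?;
    /// let t = ParallelTensor { shape: vec![2], data: vec![1.0, 2.0] };
    /// let summed = comm.all_reduce(&t, ReduceOp::Sum)?;
    /// assert_eq!(summed.data, vec![4.0, 8.0]); // each element scaled by world_size
    /// ```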
pub fn all_reduce(
&self,
tensor: &ParallelTensor,
op: ReduceOp,
) -> Result<ParallelTensor, ParallelError> {
        match op {
            // Every rank holds an identical copy, so a cross-rank sum is
            // equivalent to scaling each element by `world_size`.
            ReduceOp::Sum => {
                let data: Vec<f32> = tensor
                    .data
                    .iter()
                    .map(|x| x * self.world_size as f32)
                    .collect();
                Ok(ParallelTensor {
                    shape: tensor.shape.clone(),
                    data,
                })
            }
            // Avg, Max, and Min of identical copies are the copy itself.
            ReduceOp::Avg | ReduceOp::Max | ReduceOp::Min => Ok(tensor.clone()),
        }
}
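    /// Simulated all-gather: every rank contributes an identical shard, so the
    /// result is `world_size` copies of the local tensor concatenated along
    /// dim 0 (shape `[d0 * world_size, ...]`).
    ///
    /// ```ignore
    /// let comm = Communicator::new(2, 0)?;
    /// let t = ParallelTensor { shape: vec![1, 2], data: vec![1.0, 2.0] };
    /// let gathered = comm.all_gather(&t)?;
    /// assert_eq!(gathered.shape, vec![2, 2]);
    /// ```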
pub fn all_gather(&self, tensor: &ParallelTensor) -> Result<ParallelTensor, ParallelError> {
let mut data = Vec::with_capacity(tensor.data.len() * self.world_size);
for _ in 0..self.world_size {
data.extend_from_slice(&tensor.data);
}
let mut new_shape = tensor.shape.clone();
if !new_shape.is_empty() {
new_shape[0] *= self.world_size;
}
Ok(ParallelTensor {
shape: new_shape,
data,
})
}
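    /// Simulated reduce-scatter: conceptually an all-reduce followed by each
    /// rank keeping only its own `1/world_size` chunk, with dim 0 divided
    /// accordingly. With identical inputs on every rank, the Sum case reduces
    /// to scaling the local chunk by `world_size`.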
pub fn reduce_scatter(
&self,
tensor: &ParallelTensor,
op: ReduceOp,
) -> Result<ParallelTensor, ParallelError> {
        // Integer division truncates: callers are expected to pass tensors
        // whose length is evenly divisible by `world_size`.
        debug_assert_eq!(tensor.data.len() % self.world_size, 0);
        let chunk_size = tensor.data.len() / self.world_size;
        let start = self.rank * chunk_size;
        let end = start + chunk_size;
let chunk_data: Vec<f32> = match op {
ReduceOp::Sum => tensor.data[start..end]
.iter()
.map(|x| x * self.world_size as f32)
.collect(),
ReduceOp::Avg | ReduceOp::Max | ReduceOp::Min => tensor.data[start..end].to_vec(),
};
let mut new_shape = tensor.shape.clone();
if !new_shape.is_empty() {
new_shape[0] /= self.world_size;
}
Ok(ParallelTensor {
shape: new_shape,
data: chunk_data,
})
}
    pub fn barrier(&self) -> Result<(), ParallelError> {
        // No-op: there is nothing to synchronize in a single-process simulation.
        Ok(())
    }
pub fn world_size(&self) -> usize {
self.world_size
}
pub fn rank(&self) -> usize {
self.rank
}
}
#[derive(Debug)]
pub struct TensorParallel {
tp_size: usize,
rank: usize,
comm: Communicator,
}
impl TensorParallel {
pub fn new(tp_size: usize, rank: usize) -> Result<Self, ParallelError> {
if tp_size == 0 {
return Err(ParallelError::InvalidWorldSize(0));
}
if rank >= tp_size {
return Err(ParallelError::InvalidRank {
rank,
world_size: tp_size,
});
}
let comm = Communicator::new(tp_size, rank)?;
Ok(Self {
tp_size,
rank,
comm,
})
}
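    /// Size of each rank's shard of a partitioned dimension. Integer
    /// division: `total_size` is assumed to be divisible by `tp_size`.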
pub fn chunk_size(&self, total_size: usize) -> usize {
total_size / self.tp_size
}
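    /// Megatron-style column parallelism: the weight `W` (stored
    /// `[out_features, in_features]`) is split along its output rows, so rank
    /// `i` computes `Y_i = X * W_i^T`, a `[batch, out/tp]` slice of the
    /// output. Concatenating the slices (an all-gather along the feature
    /// dimension) recovers the full `Y = X * W^T`.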
pub fn column_parallel_linear(
&self,
input: &ParallelTensor,
weight: &ParallelTensor,
bias: Option<&ParallelTensor>,
) -> Result<ParallelTensor, ParallelError> {
let output_dim = weight.shape[0];
let chunk = self.chunk_size(output_dim);
let local_weight = weight.narrow(0, self.rank * chunk, chunk)?;
let weight_t = local_weight.transpose()?;
let mut local_output = input.matmul(&weight_t)?;
        if let Some(b) = bias {
            // The bias is sharded like the weight rows; broadcast it over the
            // batch by cycling through the row-major output buffer.
            let local_bias = b.narrow(0, self.rank * chunk, chunk)?;
            let bias_expanded = ParallelTensor {
                shape: local_output.shape.clone(),
                data: local_output
                    .data
                    .iter()
                    .enumerate()
                    .map(|(i, v)| v + local_bias.data[i % local_bias.data.len()])
                    .collect(),
            };
            local_output = bias_expanded;
        }
Ok(local_output)
}
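    /// Megatron-style row parallelism: the weight (taken here as
    /// `[in_features, out_features]`) is split along its input rows, and the
    /// input is assumed to be already sharded along its feature dimension, so
    /// the un-gathered output of a preceding column-parallel layer can feed it
    /// directly. Each rank computes a partial product `Y_i = X_i * W_i` and
    /// the full `Y = sum_i X_i * W_i` is formed by the Sum all-reduce; the
    /// unpartitioned bias is added once, after the reduction.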
pub fn row_parallel_linear(
&self,
input: &ParallelTensor,
weight: &ParallelTensor,
bias: Option<&ParallelTensor>,
) -> Result<ParallelTensor, ParallelError> {
        // Weight is taken as [in_features, out_features] so the input
        // dimension can be sharded with `narrow` along dim 0; the input is
        // assumed to be already sharded along its feature dimension.
        let input_dim = weight.shape[0];
        let chunk = self.chunk_size(input_dim);
        let local_weight = weight.narrow(0, self.rank * chunk, chunk)?;
        // Partial product over this rank's input features: [batch, out].
        let local_output = input.matmul(&local_weight)?;
        // Sum the partial products across ranks.
        let mut output = self.comm.all_reduce(&local_output, ReduceOp::Sum)?;
        // The bias is not partitioned; add it after the reduction so every
        // rank ends up with the same replicated output.
        if let Some(b) = bias {
            output = output.add(b)?;
        }
Ok(output)
}
pub fn rank(&self) -> usize {
self.rank
}
pub fn tp_size(&self) -> usize {
self.tp_size
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PipelineStage {
pub index: usize,
pub start_layer: usize,
pub end_layer: usize,
pub num_layers: usize,
}
#[derive(Debug)]
pub struct PipelineParallel {
pp_size: usize,
stage: usize,
stage_info: PipelineStage,
micro_batch_size: usize,
stats: PipelineStats,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PipelineStats {
pub micro_batches_processed: u64,
pub bubble_time_ms: f64,
pub avg_stage_latency_ms: f64,
pub forward_passes: u64,
}
impl PipelineParallel {
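    /// Builds the layer range owned by `stage` from a balanced split: each
    /// stage gets `total_layers / pp_size` layers and the first
    /// `total_layers % pp_size` stages take one extra. For example, 10 layers
    /// over 3 stages yields the ranges [0, 4), [4, 7), [7, 10).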
pub fn new(
pp_size: usize,
stage: usize,
total_layers: usize,
micro_batch_size: usize,
) -> Result<Self, ParallelError> {
if pp_size == 0 {
return Err(ParallelError::InvalidWorldSize(0));
}
if stage >= pp_size {
return Err(ParallelError::InvalidRank {
rank: stage,
world_size: pp_size,
});
}
let layers_per_stage = total_layers / pp_size;
let extra_layers = total_layers % pp_size;
let start_layer = stage * layers_per_stage + stage.min(extra_layers);
let num_layers = layers_per_stage + usize::from(stage < extra_layers);
let end_layer = start_layer + num_layers;
let stage_info = PipelineStage {
index: stage,
start_layer,
end_layer,
num_layers,
};
Ok(Self {
pp_size,
stage,
stage_info,
micro_batch_size,
stats: PipelineStats::default(),
})
}
pub fn stage_info(&self) -> &PipelineStage {
&self.stage_info
}
pub fn micro_batch_size(&self) -> usize {
self.micro_batch_size
}
pub fn is_first_stage(&self) -> bool {
self.stage == 0
}
pub fn is_last_stage(&self) -> bool {
self.stage == self.pp_size - 1
}
pub fn num_stages(&self) -> usize {
self.pp_size
}
pub fn stage(&self) -> usize {
self.stage
}
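    /// GPipe-style bubble fraction: with `p` stages and `m` micro-batches the
    /// pipeline is idle for `(p - 1) / (m + p - 1)` of the schedule, e.g.
    /// 4 stages and 8 micro-batches give 3/11 ≈ 0.27. Zero micro-batches is
    /// treated as a fully idle pipeline (ratio 1.0).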
pub fn bubble_ratio(&self, num_microbatches: usize) -> f32 {
if num_microbatches == 0 {
return 1.0;
}
(self.pp_size - 1) as f32 / (self.pp_size + num_microbatches - 1) as f32
}
pub fn stats(&self) -> &PipelineStats {
&self.stats
}
pub fn record_micro_batch(&mut self, stage_latency_ms: f64) {
self.stats.micro_batches_processed += 1;
self.stats.forward_passes += 1;
        // Incremental running mean: avg_n = (avg_{n-1} * (n - 1) + x_n) / n.
        let n = self.stats.micro_batches_processed as f64;
        self.stats.avg_stage_latency_ms =
            (self.stats.avg_stage_latency_ms * (n - 1.0) + stage_latency_ms) / n;
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct ZeroOffload {
pub offload_optimizer: bool,
pub offload_params: bool,
pub offload_activations: bool,
pub pin_memory: bool,
pub overlap_comm: bool,
}
impl Default for ZeroOffload {
fn default() -> Self {
Self {
offload_optimizer: true,
offload_params: false,
offload_activations: false,
pin_memory: true,
overlap_comm: true,
}
}
}
impl ZeroOffload {
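    /// Preset for inference workloads: parameters and activations are
    /// offloaded to (pinned) host memory, and optimizer offload is disabled
    /// since there is no optimizer state to hold.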
pub fn inference() -> Self {
Self {
            offload_optimizer: false,
            offload_params: true,
offload_activations: true,
pin_memory: true,
overlap_comm: true,
}
}
    pub fn memory_savings_ratio(&self) -> f32 {
        // Rough heuristic: parameter offload is assumed to halve resident
        // device memory, and activation offload to cut the remainder by a
        // further 30%. Optimizer offload is not modeled by this ratio.
        let mut ratio = 1.0;
        if self.offload_params {
            ratio *= 0.5;
        }
        if self.offload_activations {
            ratio *= 0.7;
        }
        1.0 - ratio
    }
}
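
#[cfg(test)]
mod tests {
    // A minimal sketch exercising the simulated semantics above; assumes
    // `ParallelTensor`, `ReduceOp`, and `ParallelError` are visible from this
    // module as used throughout the file.
    use super::*;

    #[test]
    fn all_reduce_sum_scales_by_world_size() {
        let comm = Communicator::new(4, 0).unwrap();
        let t = ParallelTensor {
            shape: vec![2],
            data: vec![1.0, 2.0],
        };
        let out = comm.all_reduce(&t, ReduceOp::Sum).unwrap();
        assert_eq!(out.data, vec![4.0, 8.0]);
    }

    #[test]
    fn bubble_ratio_matches_gpipe_formula() {
        let pp = PipelineParallel::new(4, 0, 32, 1).unwrap();
        // (p - 1) / (m + p - 1) = 3 / 11 for 4 stages and 8 micro-batches.
        assert!((pp.bubble_ratio(8) - 3.0_f32 / 11.0).abs() < 1e-6);
    }

    #[test]
    fn inference_offload_savings() {
        let z = ZeroOffload::inference();
        // 1.0 - 0.5 * 0.7 = 0.65 of resident memory saved (heuristic).
        assert!((z.memory_savings_ratio() - 0.65_f32).abs() < 1e-6);
    }
}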