use super::communication::{
AsyncCommunicator, CommunicationError, CompressionStrategy, MessagePriority, TensorMessage,
};
use super::coordinator::CoordinatorError;
use super::process::{Communicator, ProcessError};
use crate::error::NumRs2Error;
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::{Mutex, RwLock};
/// Errors produced by the model-parallel primitives in this module.
///
/// The first three variants wrap errors from sibling modules and are created
/// implicitly by `?` through their `#[from]` conversions; the remaining
/// variants carry structured fields or a free-form message.
#[derive(Error, Debug)]
pub enum ModelParallelError {
/// Wraps a [`ProcessError`].
#[error("Process error: {0}")]
Process(#[from] ProcessError),
/// Wraps a [`CommunicationError`].
#[error("Communication error: {0}")]
Communication(#[from] CommunicationError),
/// Wraps a [`CoordinatorError`].
#[error("Coordinator error: {0}")]
Coordinator(#[from] CoordinatorError),
/// A stage index that is not valid for a pipeline with `total` stages.
#[error("Invalid stage assignment: stage {stage} out of {total}")]
InvalidStage { stage: usize, total: usize },
/// Tensor partitioning failed; the string describes the reason.
#[error("Partition error: {0}")]
PartitionError(String),
/// A pipeline operation failed; the string describes the reason.
#[error("Pipeline error: {0}")]
PipelineError(String),
/// A checkpoint operation failed; the string describes the reason.
#[error("Checkpoint error: {0}")]
CheckpointError(String),
/// A microbatch was rejected; the string describes the reason.
#[error("Invalid microbatch: {0}")]
InvalidMicrobatch(String),
}
impl From<ModelParallelError> for NumRs2Error {
fn from(err: ModelParallelError) -> Self {
NumRs2Error::DistributedComputing(err.to_string())
}
}
/// How `TensorParallel::partition` splits a tensor across ranks.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, oxicode::Encode, oxicode::Decode,
)]
pub enum PartitionStrategy {
/// Each rank keeps a contiguous band of columns of a 2-D tensor (indexed as `row * cols + col`).
ColumnWise,
/// Each rank keeps a contiguous band of rows of a 2-D tensor (indexed as `row * cols + col`).
RowWise,
/// The flat buffer is cut into contiguous chunks, one per rank; the shape is not consulted.
BatchWise,
/// Handled exactly like `BatchWise` by `TensorParallel::partition`.
SequenceWise,
}
/// Position of a group of ranks within a linear pipeline.
#[derive(Debug, Clone)]
pub struct PipelineStage {
    /// Zero-based index of this stage.
    pub stage_id: usize,
    /// Total number of stages in the pipeline.
    pub num_stages: usize,
    /// Ranks assigned to this stage.
    pub ranks: Vec<usize>,
    /// Index of the preceding stage, or `None` for stage 0.
    pub prev_stage: Option<usize>,
    /// Index of the following stage, or `None` when `stage_id + 1` is not a
    /// valid stage index.
    pub next_stage: Option<usize>,
}

impl PipelineStage {
    /// Creates a stage descriptor and derives its neighbours.
    ///
    /// `prev_stage` is `stage_id - 1` unless `stage_id` is 0, and `next_stage`
    /// is `stage_id + 1` when that index is below `num_stages`.
    ///
    /// The arithmetic is checked, so `num_stages == 0` (or a `stage_id` at the
    /// numeric limit) simply yields `None` instead of underflowing.
    pub fn new(stage_id: usize, num_stages: usize, ranks: Vec<usize>) -> Self {
        let prev_stage = stage_id.checked_sub(1);
        let next_stage = stage_id
            .checked_add(1)
            .filter(|&candidate| candidate < num_stages);
        Self {
            stage_id,
            num_stages,
            ranks,
            prev_stage,
            next_stage,
        }
    }

    /// Returns `true` when this is stage 0.
    pub fn is_first(&self) -> bool {
        self.stage_id == 0
    }

    /// Returns `true` when this is the final stage, i.e.
    /// `stage_id + 1 == num_stages`.
    ///
    /// Written with `checked_add` rather than `num_stages - 1` so that a
    /// pipeline with zero stages reports `false` instead of underflowing.
    pub fn is_last(&self) -> bool {
        self.stage_id.checked_add(1) == Some(self.num_stages)
    }
}
/// A chunk of data tagged with an identifier and the pipeline stage it has reached.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Microbatch<T> {
/// Identifier of this microbatch.
pub id: usize,
/// Element buffer.
pub data: Vec<T>,
/// Shape stored alongside `data`. Presumably the logical tensor dimensions;
/// nothing in this file checks it against `data.len()`.
pub shape: Vec<usize>,
/// Stage counter: set to 0 by `new` and incremented by `advance_stage`.
pub current_stage: usize,
}
impl<T: Clone> Microbatch<T> {
pub fn new(id: usize, data: Vec<T>, shape: Vec<usize>) -> Self {
Self {
id,
data,
shape,
current_stage: 0,
}
}
pub fn advance_stage(&mut self) {
self.current_stage += 1;
}
}
/// Per-rank state for pipeline parallelism.
pub struct PipelineParallel {
/// Shared communicator; queried for this rank and the world size.
communicator: Arc<Communicator>,
/// Asynchronous wrapper built from `communicator`; the send methods call `isend` on it.
async_comm: AsyncCommunicator,
/// Stage this rank was assigned to in `new`.
stage: PipelineStage,
/// Number of microbatches, stored as passed to `new`.
num_microbatches: usize,
/// Buffers keyed by microbatch id; `recv_forward` removes entries from here.
forward_buffer: Arc<Mutex<HashMap<usize, Vec<f32>>>>,
/// Buffers keyed by microbatch id; `recv_backward` removes entries from here.
backward_buffer: Arc<Mutex<HashMap<usize, Vec<f32>>>>,
/// Scheduling policy; `new` always sets `PipelineSchedule::GPipe`.
schedule: PipelineSchedule,
}
/// Scheduling policy recorded on a `PipelineParallel` instance.
///
/// In the visible code the value is only stored (always as `GPipe`) and is
/// not read by any method.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PipelineSchedule {
/// Presumably the GPipe schedule (all forward passes, then all backward
/// passes) — TODO confirm once a scheduler consumes this value.
GPipe,
/// Presumably the "one forward, one backward" (1F1B) schedule — TODO confirm;
/// this variant is not constructed in the visible code.
OneFOneB,
}
impl PipelineParallel {
pub fn new(
communicator: Arc<Communicator>,
num_stages: usize,
num_microbatches: usize,
) -> Result<Self, ModelParallelError> {
let async_comm = AsyncCommunicator::new(communicator.clone())?;
let rank = communicator.rank();
let world_size = communicator.size();
let ranks_per_stage = world_size.div_ceil(num_stages);
let stage_id = rank / ranks_per_stage;
if stage_id >= num_stages {
return Err(ModelParallelError::InvalidStage {
stage: stage_id,
total: num_stages,
});
}
let stage_start = stage_id * ranks_per_stage;
let stage_end = (stage_start + ranks_per_stage).min(world_size);
let ranks: Vec<usize> = (stage_start..stage_end).collect();
let stage = PipelineStage::new(stage_id, num_stages, ranks);
Ok(Self {
communicator,
async_comm,
stage,
num_microbatches,
forward_buffer: Arc::new(Mutex::new(HashMap::new())),
backward_buffer: Arc::new(Mutex::new(HashMap::new())),
schedule: PipelineSchedule::GPipe,
})
}
pub async fn send_forward(
&self,
microbatch_id: usize,
activations: &[f32],
) -> Result<(), ModelParallelError> {
if let Some(next_stage) = self.stage.next_stage {
let next_rank = next_stage * self.communicator.size().div_ceil(self.stage.num_stages);
let msg = TensorMessage::new(
activations.to_vec(),
CompressionStrategy::None,
MessagePriority::High,
)
.with_tag(microbatch_id as u32);
self.async_comm.isend(msg, next_rank).await?;
}
Ok(())
}
pub async fn recv_forward(&self, microbatch_id: usize) -> Result<Vec<f32>, ModelParallelError> {
if let Some(prev_stage) = self.stage.prev_stage {
let prev_rank = prev_stage * self.communicator.size().div_ceil(self.stage.num_stages);
let mut buffer = self.forward_buffer.lock().await;
if let Some(data) = buffer.remove(µbatch_id) {
return Ok(data);
}
drop(buffer);
let _ = prev_rank;
Ok(vec![0.0; 10]) } else {
Err(ModelParallelError::PipelineError(
"No previous stage to receive from".to_string(),
))
}
}
pub async fn send_backward(
&self,
microbatch_id: usize,
gradients: &[f32],
) -> Result<(), ModelParallelError> {
if let Some(prev_stage) = self.stage.prev_stage {
let prev_rank = prev_stage * self.communicator.size().div_ceil(self.stage.num_stages);
let msg = TensorMessage::new(
gradients.to_vec(),
CompressionStrategy::None,
MessagePriority::High,
)
.with_tag(microbatch_id as u32);
self.async_comm.isend(msg, prev_rank).await?;
}
Ok(())
}
pub async fn recv_backward(
&self,
microbatch_id: usize,
) -> Result<Vec<f32>, ModelParallelError> {
if let Some(next_stage) = self.stage.next_stage {
let next_rank = next_stage * self.communicator.size().div_ceil(self.stage.num_stages);
let mut buffer = self.backward_buffer.lock().await;
if let Some(data) = buffer.remove(µbatch_id) {
return Ok(data);
}
drop(buffer);
let _ = next_rank;
Ok(vec![0.0; 10]) } else {
Err(ModelParallelError::PipelineError(
"No next stage to receive from".to_string(),
))
}
}
pub fn stage(&self) -> &PipelineStage {
&self.stage
}
pub fn num_microbatches(&self) -> usize {
self.num_microbatches
}
}
/// Per-rank state for tensor parallelism.
pub struct TensorParallel {
/// Shared communicator this instance was built from.
communicator: Arc<Communicator>,
/// Asynchronous wrapper created from `communicator` in `new`.
async_comm: AsyncCommunicator,
/// How `partition` splits tensors.
strategy: PartitionStrategy,
/// Number of partitions; taken from `communicator.size()` in `new`.
tp_size: usize,
/// Index of this rank's partition; taken from `communicator.rank()` in `new`.
tp_rank: usize,
}
impl TensorParallel {
    /// Creates the tensor-parallel state for the calling rank.
    ///
    /// The partition count and this rank's partition index are read from the
    /// communicator's size and rank.
    ///
    /// # Errors
    ///
    /// Any error returned by `AsyncCommunicator::new`.
    pub fn new(
        communicator: Arc<Communicator>,
        strategy: PartitionStrategy,
    ) -> Result<Self, ModelParallelError> {
        let async_comm = AsyncCommunicator::new(Arc::clone(&communicator))?;
        let tp_size = communicator.size();
        let tp_rank = communicator.rank();
        Ok(Self {
            communicator,
            async_comm,
            strategy,
            tp_size,
            tp_rank,
        })
    }

    /// Extracts this rank's share of `tensor`.
    ///
    /// * `ColumnWise` / `RowWise` treat `tensor` as a row-major matrix of
    ///   `shape = [rows, cols]` and keep a contiguous band of columns / rows.
    /// * `BatchWise` / `SequenceWise` ignore `shape` and keep a contiguous
    ///   chunk of the flat buffer.
    ///
    /// The split extent is divided into `ceil(extent / tp_size)`-sized pieces,
    /// so trailing ranks may receive a shorter or empty partition. Matrix
    /// indices that fall outside `tensor` are skipped rather than reported.
    ///
    /// # Errors
    ///
    /// [`ModelParallelError::PartitionError`] if the partition count is zero,
    /// or if a matrix strategy is used with a `shape` that is not 2-D.
    pub fn partition(
        &self,
        tensor: &[f32],
        shape: &[usize],
    ) -> Result<Vec<f32>, ModelParallelError> {
        // `div_ceil` panics on a zero divisor; surface that as an error.
        if self.tp_size == 0 {
            return Err(ModelParallelError::PartitionError(
                "tensor-parallel size must be greater than zero".to_string(),
            ));
        }

        match self.strategy {
            PartitionStrategy::ColumnWise => {
                let &[rows, cols] = shape else {
                    return Err(ModelParallelError::PartitionError(
                        "ColumnWise partition requires 2D tensor".to_string(),
                    ));
                };
                Ok(Self::collect_submatrix(
                    tensor,
                    cols,
                    0..rows,
                    self.local_range(cols),
                ))
            }
            PartitionStrategy::RowWise => {
                let &[rows, cols] = shape else {
                    return Err(ModelParallelError::PartitionError(
                        "RowWise partition requires 2D tensor".to_string(),
                    ));
                };
                Ok(Self::collect_submatrix(
                    tensor,
                    cols,
                    self.local_range(rows),
                    0..cols,
                ))
            }
            PartitionStrategy::BatchWise | PartitionStrategy::SequenceWise => {
                // `local_range` clamps both ends to `tensor.len()`, so ranks
                // past the end of the data get an empty slice instead of an
                // out-of-order range that would panic.
                let span = self.local_range(tensor.len());
                Ok(tensor[span].to_vec())
            }
        }
    }

    /// Reassembles the full tensor from per-rank partitions.
    ///
    /// NOTE: currently a stub that returns a copy of `local_tensor` without
    /// communicating with any other rank.
    pub async fn gather(&self, local_tensor: &[f32]) -> Result<Vec<f32>, ModelParallelError> {
        Ok(local_tensor.to_vec())
    }

    /// The partitioning strategy chosen at construction time.
    pub fn strategy(&self) -> PartitionStrategy {
        self.strategy
    }

    /// Number of partitions.
    pub fn tp_size(&self) -> usize {
        self.tp_size
    }

    /// Index of this rank's partition.
    pub fn tp_rank(&self) -> usize {
        self.tp_rank
    }

    /// Half-open index range owned by this rank when `extent` items are split
    /// into `tp_size` contiguous pieces. Both ends are clamped to `extent`.
    ///
    /// Callers must ensure `tp_size` is non-zero.
    fn local_range(&self, extent: usize) -> std::ops::Range<usize> {
        let per_rank = extent.div_ceil(self.tp_size);
        let start = self.tp_rank.saturating_mul(per_rank).min(extent);
        let end = start.saturating_add(per_rank).min(extent);
        start..end
    }

    /// Copies the elements at `rows` x `cols` out of a row-major buffer whose
    /// rows are `row_stride` elements apart, skipping indices past the end of
    /// `tensor`.
    fn collect_submatrix(
        tensor: &[f32],
        row_stride: usize,
        rows: std::ops::Range<usize>,
        cols: std::ops::Range<usize>,
    ) -> Vec<f32> {
        rows.flat_map(|row| cols.clone().map(move |col| row * row_stride + col))
            .filter_map(|idx| tensor.get(idx).copied())
            .collect()
    }
}
/// Stores selected layer activations and counts recomputations.
pub struct ActivationCheckpointer {
/// Checkpoint spacing: `should_checkpoint` is true for layer ids that are multiples of this value.
interval: usize,
/// Saved activations keyed by layer id.
checkpoints: Arc<RwLock<HashMap<usize, Vec<f32>>>>,
/// Counter bumped by `increment_recomputation`.
recomputation_count: Arc<Mutex<usize>>,
}
impl ActivationCheckpointer {
    /// Creates an empty checkpointer with the given layer `interval`.
    ///
    /// The current implementation always returns `Ok`.
    pub fn new(interval: usize) -> Result<Self, ModelParallelError> {
        let checkpoints = Arc::new(RwLock::new(HashMap::new()));
        let recomputation_count = Arc::new(Mutex::new(0));
        let checkpointer = Self {
            interval,
            checkpoints,
            recomputation_count,
        };
        Ok(checkpointer)
    }

    /// Whether `layer_id` is an integer multiple of the configured interval.
    pub fn should_checkpoint(&self, layer_id: usize) -> bool {
        let spacing = self.interval;
        layer_id.is_multiple_of(spacing)
    }

    /// Stores `activations` under `layer_id`, replacing any earlier entry.
    pub async fn checkpoint(&self, layer_id: usize, activations: Vec<f32>) {
        // The write guard is a temporary and is released with this statement.
        self.checkpoints.write().await.insert(layer_id, activations);
    }

    /// Returns a copy of the activations saved for `layer_id`, if any.
    pub async fn get_checkpoint(&self, layer_id: usize) -> Option<Vec<f32>> {
        let saved = self.checkpoints.read().await;
        saved.get(&layer_id).cloned()
    }

    /// Discards every stored checkpoint.
    pub async fn clear(&self) {
        self.checkpoints.write().await.clear();
    }

    /// Current value of the recomputation counter.
    pub async fn recomputation_count(&self) -> usize {
        let tally = self.recomputation_count.lock().await;
        *tally
    }

    /// Adds one to the recomputation counter.
    pub async fn increment_recomputation(&self) {
        *self.recomputation_count.lock().await += 1;
    }

    /// The interval supplied at construction time.
    pub fn interval(&self) -> usize {
        self.interval
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every strategy variant survives an oxicode encode/decode round trip.
    #[test]
    fn test_partition_strategy_serialization() {
        use std::mem::discriminant;

        let all_variants = [
            PartitionStrategy::ColumnWise,
            PartitionStrategy::RowWise,
            PartitionStrategy::BatchWise,
            PartitionStrategy::SequenceWise,
        ];
        for original in all_variants {
            let encoded = oxicode::encode_to_vec(&original);
            assert!(encoded.is_ok());
            let bytes = encoded.expect("serialization failed");

            let decoded = oxicode::decode_from_slice::<PartitionStrategy>(&bytes);
            assert!(decoded.is_ok());
            let (round_tripped, _) = decoded.expect("deserialization failed");

            assert_eq!(discriminant(&original), discriminant(&round_tripped));
        }
    }

    /// A middle stage records its fields and both neighbours.
    #[test]
    fn test_pipeline_stage_creation() {
        let middle = PipelineStage::new(1, 4, vec![2, 3]);

        assert_eq!(middle.stage_id, 1);
        assert_eq!(middle.num_stages, 4);
        assert_eq!(middle.ranks, vec![2, 3]);
        assert_eq!(middle.prev_stage, Some(0));
        assert_eq!(middle.next_stage, Some(2));
        assert!(!middle.is_first());
        assert!(!middle.is_last());
    }

    /// Stage 0 has a successor but no predecessor.
    #[test]
    fn test_pipeline_stage_first() {
        let head = PipelineStage::new(0, 4, vec![0]);

        assert!(head.is_first());
        assert!(!head.is_last());
        assert_eq!(head.prev_stage, None);
        assert_eq!(head.next_stage, Some(1));
    }

    /// The final stage has a predecessor but no successor.
    #[test]
    fn test_pipeline_stage_last() {
        let tail = PipelineStage::new(3, 4, vec![6, 7]);

        assert!(!tail.is_first());
        assert!(tail.is_last());
        assert_eq!(tail.prev_stage, Some(2));
        assert_eq!(tail.next_stage, None);
    }

    /// A new microbatch keeps its inputs and starts at stage 0.
    #[test]
    fn test_microbatch_creation() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let shape = vec![2, 2];

        let microbatch = Microbatch::new(0, data.clone(), shape.clone());

        assert_eq!(microbatch.id, 0);
        assert_eq!(microbatch.data, data);
        assert_eq!(microbatch.shape, shape);
        assert_eq!(microbatch.current_stage, 0);
    }

    /// Each `advance_stage` call moves the stage counter up by one.
    #[test]
    fn test_microbatch_advance() {
        let mut microbatch = Microbatch::new(0, vec![1.0], vec![1]);
        assert_eq!(microbatch.current_stage, 0);

        for expected_stage in [1, 2] {
            microbatch.advance_stage();
            assert_eq!(microbatch.current_stage, expected_stage);
        }
    }

    /// With an interval of 2, only even layers are checkpointed.
    #[test]
    fn test_activation_checkpointer_should_checkpoint() {
        let checkpointer = ActivationCheckpointer::new(2).expect("checkpointer creation failed");

        let expectations = [(0, true), (1, false), (2, true), (3, false), (4, true)];
        for (layer_id, expected) in expectations {
            assert_eq!(checkpointer.should_checkpoint(layer_id), expected);
        }
    }

    /// A stored checkpoint can be read back unchanged.
    #[tokio::test]
    async fn test_activation_checkpointer_store_retrieve() {
        let checkpointer = ActivationCheckpointer::new(1).expect("checkpointer creation failed");
        let activations = vec![1.0, 2.0, 3.0];

        checkpointer.checkpoint(0, activations.clone()).await;

        let retrieved = checkpointer.get_checkpoint(0).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.expect("checkpoint retrieval failed"), activations);
    }

    /// `clear` removes every stored checkpoint.
    #[tokio::test]
    async fn test_activation_checkpointer_clear() {
        let checkpointer = ActivationCheckpointer::new(1).expect("checkpointer creation failed");
        checkpointer.checkpoint(0, vec![1.0, 2.0]).await;
        checkpointer.checkpoint(1, vec![3.0, 4.0]).await;

        checkpointer.clear().await;

        for layer_id in [0, 1] {
            assert_eq!(checkpointer.get_checkpoint(layer_id).await, None);
        }
    }

    /// The recomputation counter starts at zero and counts increments.
    #[tokio::test]
    async fn test_activation_checkpointer_recomputation_count() {
        let checkpointer = ActivationCheckpointer::new(2).expect("checkpointer creation failed");
        assert_eq!(checkpointer.recomputation_count().await, 0);

        for expected_count in [1, 2] {
            checkpointer.increment_recomputation().await;
            assert_eq!(checkpointer.recomputation_count().await, expected_count);
        }
    }

    /// `interval` reports the value given to `new`.
    #[test]
    fn test_activation_checkpointer_interval() {
        let checkpointer = ActivationCheckpointer::new(3).expect("checkpointer creation failed");
        assert_eq!(checkpointer.interval(), 3);
    }

    /// Strategy variants compare equal only to themselves.
    #[test]
    fn test_partition_strategy_equality() {
        let column_wise = PartitionStrategy::ColumnWise;
        assert_eq!(column_wise, PartitionStrategy::ColumnWise);
        assert_ne!(column_wise, PartitionStrategy::RowWise);
    }
}