#![allow(clippy::must_use_candidate)]
#![allow(clippy::return_self_not_must_use)]
#![allow(clippy::missing_errors_doc)]
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors returned by the parallel configuration, tensor, and
/// communication helpers in this module.
#[derive(Debug, Error)]
pub enum ParallelError {
/// A rank was outside the valid range `0..world_size`.
#[error("Invalid rank {rank} for world size {world_size}")]
InvalidRank {
rank: usize,
world_size: usize,
},
/// The requested world size (product of the parallel dimensions) is unusable.
#[error("Invalid world size: {0}")]
InvalidWorldSize(usize),
/// A transport-level failure, carrying a human-readable description.
#[error("Communication error: {0}")]
CommunicationError(String),
/// Two shapes (or a shape and an expected size) disagreed.
#[error("Tensor shape mismatch: expected {expected:?}, got {got:?}")]
ShapeMismatch {
expected: Vec<usize>,
got: Vec<usize>,
},
/// A pipeline-parallel stage failed; the message describes the failure.
#[error("Pipeline stage error: {0}")]
PipelineError(String),
/// An operation required a parallel context that was never set up.
#[error("Parallel context not initialized")]
NotInitialized,
}
/// Reduction operator selectable for collective operations.
///
/// NOTE(review): the concrete reduction semantics are implemented by the
/// communicator in the included `mod_all_reduce_communicator.rs`, which is
/// not visible here — confirm variant behavior there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReduceOp {
/// Element-wise sum.
Sum,
/// Element-wise maximum.
Max,
/// Element-wise minimum.
Min,
/// Element-wise average.
Avg,
}
/// Describes how a single process fits into a 3-D parallel layout
/// (tensor parallelism x pipeline parallelism x data parallelism).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
/// Tensor-parallel group size.
pub tp_size: usize,
/// Pipeline-parallel group size (number of stages).
pub pp_size: usize,
/// Data-parallel group size.
pub dp_size: usize,
/// This process's global rank, in `0..world_size` when built via `new`.
pub rank: usize,
/// Total number of processes: `tp_size * pp_size * dp_size`.
pub world_size: usize,
}
impl ParallelConfig {
    /// Builds a configuration for a `tp_size * pp_size * dp_size` world.
    ///
    /// Ranks are laid out with the tensor-parallel dimension varying
    /// fastest, then pipeline stages, then data-parallel replicas
    /// (see [`Self::tp_rank`], [`Self::pp_stage`], [`Self::dp_rank`]).
    ///
    /// # Errors
    /// * [`ParallelError::InvalidWorldSize`] if any dimension is zero, or if
    ///   the product overflows `usize` (reported as `usize::MAX` since the
    ///   true size is unrepresentable).
    /// * [`ParallelError::InvalidRank`] if `rank >= world_size`.
    pub fn new(
        tp_size: usize,
        pp_size: usize,
        dp_size: usize,
        rank: usize,
    ) -> Result<Self, ParallelError> {
        // checked_mul: a wrapping product in release builds could make an
        // out-of-range rank look valid; surface overflow as an error instead.
        let world_size = tp_size
            .checked_mul(pp_size)
            .and_then(|tp_pp| tp_pp.checked_mul(dp_size))
            .ok_or(ParallelError::InvalidWorldSize(usize::MAX))?;
        if world_size == 0 {
            return Err(ParallelError::InvalidWorldSize(0));
        }
        if rank >= world_size {
            return Err(ParallelError::InvalidRank { rank, world_size });
        }
        Ok(Self {
            tp_size,
            pp_size,
            dp_size,
            rank,
            world_size,
        })
    }

    /// A trivial single-process configuration (all group sizes 1, rank 0).
    pub fn single() -> Self {
        Self {
            tp_size: 1,
            pp_size: 1,
            dp_size: 1,
            rank: 0,
            world_size: 1,
        }
    }

    /// Rank of this process within its tensor-parallel group.
    pub fn tp_rank(&self) -> usize {
        self.rank % self.tp_size
    }

    /// Pipeline stage index of this process.
    pub fn pp_stage(&self) -> usize {
        (self.rank / self.tp_size) % self.pp_size
    }

    /// Rank of this process within its data-parallel group.
    pub fn dp_rank(&self) -> usize {
        self.rank / (self.tp_size * self.pp_size)
    }

    /// True for the first rank of the tensor-parallel group.
    pub fn is_tp_first(&self) -> bool {
        self.tp_rank() == 0
    }

    /// True for the last rank of the tensor-parallel group.
    pub fn is_tp_last(&self) -> bool {
        self.tp_rank() == self.tp_size - 1
    }

    /// True for the first pipeline stage.
    pub fn is_pp_first(&self) -> bool {
        self.pp_stage() == 0
    }

    /// True for the last pipeline stage.
    pub fn is_pp_last(&self) -> bool {
        self.pp_stage() == self.pp_size - 1
    }
}
impl Default for ParallelConfig {
fn default() -> Self {
Self::single()
}
}
/// A dense tensor of `f32` values stored row-major (last dimension
/// contiguous) in a flat buffer.
#[derive(Debug, Clone)]
pub struct ParallelTensor {
/// Dimension sizes; constructors keep `shape.iter().product() == data.len()`
/// (fields are public, so callers can break this invariant directly).
pub shape: Vec<usize>,
/// Flattened element storage in row-major order.
pub data: Vec<f32>,
}
impl ParallelTensor {
    /// Creates a tensor from `shape` and row-major `data`.
    ///
    /// # Errors
    /// [`ParallelError::ShapeMismatch`] if `data.len()` differs from the
    /// product of `shape` (both sizes reported as 1-element vectors).
    pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, ParallelError> {
        let expected_size: usize = shape.iter().product();
        if data.len() != expected_size {
            return Err(ParallelError::ShapeMismatch {
                expected: vec![expected_size],
                got: vec![data.len()],
            });
        }
        Ok(Self { shape, data })
    }

    /// A tensor of the given shape filled with zeros.
    pub fn zeros(shape: Vec<usize>) -> Self {
        let size: usize = shape.iter().product();
        Self {
            shape,
            data: vec![0.0; size],
        }
    }

    /// Returns a copy of the window `[start, start + length)` along
    /// dimension `dim`, for tensors of any rank.
    ///
    /// # Errors
    /// [`ParallelError::ShapeMismatch`] if `dim` is out of range or the
    /// window exceeds `shape[dim]`.
    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self, ParallelError> {
        if dim >= self.shape.len() {
            return Err(ParallelError::ShapeMismatch {
                expected: vec![dim],
                got: self.shape.clone(),
            });
        }
        // Fail cleanly instead of panicking on an out-of-range slice below.
        if start + length > self.shape[dim] {
            return Err(ParallelError::ShapeMismatch {
                expected: vec![self.shape[dim]],
                got: vec![start + length],
            });
        }
        // Row-major layout: the buffer is `outer` blocks, each holding
        // `shape[dim]` runs of `inner` contiguous elements. Narrowing keeps
        // `length` runs (starting at `start`) from every block. This single
        // stride-based path reproduces the old 1-D/2-D behavior and fixes the
        // rank > 2 case, which previously copied a flat `start..start+length`
        // slice and produced data inconsistent with the reported shape.
        let outer: usize = self.shape[..dim].iter().product();
        let inner: usize = self.shape[dim + 1..].iter().product();
        let dim_len = self.shape[dim];
        let mut new_data = Vec::with_capacity(outer * length * inner);
        for block in 0..outer {
            let begin = (block * dim_len + start) * inner;
            new_data.extend_from_slice(&self.data[begin..begin + length * inner]);
        }
        let mut new_shape = self.shape.clone();
        new_shape[dim] = length;
        Ok(Self {
            shape: new_shape,
            data: new_data,
        })
    }

    /// Returns the transpose of a 2-D tensor.
    ///
    /// # Errors
    /// [`ParallelError::ShapeMismatch`] if the tensor is not 2-D.
    pub fn transpose(&self) -> Result<Self, ParallelError> {
        if self.shape.len() != 2 {
            return Err(ParallelError::ShapeMismatch {
                expected: vec![2],
                got: vec![self.shape.len()],
            });
        }
        let rows = self.shape[0];
        let cols = self.shape[1];
        let mut new_data = vec![0.0; rows * cols];
        for i in 0..rows {
            for j in 0..cols {
                new_data[j * rows + i] = self.data[i * cols + j];
            }
        }
        Ok(Self {
            shape: vec![cols, rows],
            data: new_data,
        })
    }

    /// Naive `(m, k) x (k, n) -> (m, n)` matrix multiplication.
    ///
    /// # Errors
    /// [`ParallelError::ShapeMismatch`] if either operand is not 2-D or the
    /// inner dimensions disagree.
    pub fn matmul(&self, other: &Self) -> Result<Self, ParallelError> {
        if self.shape.len() != 2 || other.shape.len() != 2 {
            return Err(ParallelError::ShapeMismatch {
                expected: vec![2, 2],
                got: vec![self.shape.len(), other.shape.len()],
            });
        }
        let m = self.shape[0];
        let k = self.shape[1];
        let n = other.shape[1];
        if k != other.shape[0] {
            return Err(ParallelError::ShapeMismatch {
                expected: vec![k],
                got: vec![other.shape[0]],
            });
        }
        let mut result = vec![0.0; m * n];
        for i in 0..m {
            for j in 0..n {
                let mut sum = 0.0;
                for l in 0..k {
                    sum += self.data[i * k + l] * other.data[l * n + j];
                }
                result[i * n + j] = sum;
            }
        }
        Ok(Self {
            shape: vec![m, n],
            data: result,
        })
    }

    /// Element-wise sum of two tensors of identical shape.
    ///
    /// # Errors
    /// [`ParallelError::ShapeMismatch`] if the shapes differ.
    pub fn add(&self, other: &Self) -> Result<Self, ParallelError> {
        if self.shape != other.shape {
            return Err(ParallelError::ShapeMismatch {
                expected: self.shape.clone(),
                got: other.shape.clone(),
            });
        }
        let data: Vec<f32> = self
            .data
            .iter()
            .zip(&other.data)
            .map(|(a, b)| a + b)
            .collect();
        Ok(Self {
            shape: self.shape.clone(),
            data,
        })
    }

    /// Sum of all elements.
    pub fn sum(&self) -> f32 {
        self.data.iter().sum()
    }

    /// Total number of elements.
    pub fn numel(&self) -> usize {
        self.data.len()
    }
}
/// Handle identifying one participant in a communication group.
///
/// NOTE(review): the collective/point-to-point methods live in the included
/// `mod_all_reduce_communicator.rs`, which is not visible here — field
/// semantics below are inferred from names; confirm against that file.
#[derive(Debug, Clone)]
pub struct Communicator {
// Total number of participants in the group.
world_size: usize,
// This participant's index, presumably in `0..world_size`.
rank: usize,
}
include!("mod_all_reduce_communicator.rs");
include!("distributed_context_impl.rs");