//! Lightweight descriptors for distributed tensor placement: device identity
//! and grouping, sharding strategies, collective-communication metadata,
//! checkpoint bookkeeping, and cluster topology.

use core::fmt;
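/// Identifies a single accelerator by its position in the cluster topology:
/// the node it lives on, the rack that node belongs to, and its local index
/// within the node.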
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DeviceId {
node_id: usize,
rack_id: usize,
local_device_id: usize,
}
impl DeviceId {
pub fn new(node_id: usize, rack_id: usize, local_device_id: usize) -> Self {
Self {
node_id,
rack_id,
local_device_id,
}
}
pub fn simple(local_device_id: usize) -> Self {
Self::new(0, 0, local_device_id)
}
pub fn node_id(&self) -> usize {
self.node_id
}
pub fn rack_id(&self) -> usize {
self.rack_id
}
pub fn local_device_id(&self) -> usize {
self.local_device_id
}
    /// Flattened identifier; assumes fewer than 10 nodes per rack and fewer
    /// than 100 devices per node, otherwise distinct devices can collide.
    pub fn global_id(&self) -> usize {
        self.rack_id * 1000 + self.node_id * 100 + self.local_device_id
    }
}
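/// An ordered set of devices that participate in the same sharding or
/// collective operation; a device's position in the list is its rank.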
#[derive(Debug, Clone)]
pub struct DeviceGroup {
devices: Vec<DeviceId>,
name: Option<String>,
}
impl DeviceGroup {
pub fn new(device_ids: Vec<usize>) -> Self {
let devices = device_ids.iter().map(|&id| DeviceId::simple(id)).collect();
Self {
devices,
name: None,
}
}
pub fn from_devices(devices: Vec<DeviceId>) -> Self {
Self {
devices,
name: None,
}
}
pub fn with_name(mut self, name: impl Into<String>) -> Self {
self.name = Some(name.into());
self
}
pub fn devices(&self) -> &[DeviceId] {
&self.devices
}
pub fn size(&self) -> usize {
self.devices.len()
}
pub fn contains(&self, device_id: &DeviceId) -> bool {
self.devices.contains(device_id)
}
pub fn rank(&self, device_id: &DeviceId) -> Option<usize> {
self.devices.iter().position(|d| d == device_id)
}
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
}
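/// How a tensor is laid out across a device group: fully replicated, split
/// along the batch dimension (`DataParallel`), split along an arbitrary
/// dimension (`DimSharded`), or one of the coarser model/pipeline/hybrid
/// schemes.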
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardingStrategy {
Replicated,
DataParallel,
ModelParallel,
DimSharded(usize),
Pipeline,
Hybrid,
}
impl fmt::Display for ShardingStrategy {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ShardingStrategy::Replicated => write!(f, "Replicated"),
ShardingStrategy::DataParallel => write!(f, "DataParallel"),
ShardingStrategy::ModelParallel => write!(f, "ModelParallel"),
ShardingStrategy::DimSharded(dim) => write!(f, "DimSharded({})", dim),
ShardingStrategy::Pipeline => write!(f, "Pipeline"),
ShardingStrategy::Hybrid => write!(f, "Hybrid"),
}
}
}
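/// One device's piece of a distributed tensor: the offset of the piece in the
/// global tensor, its local shape, and the owning device's rank in the group.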
#[derive(Debug, Clone)]
pub struct Shard {
device_id: DeviceId,
offset: Vec<usize>,
shape: Vec<usize>,
rank: usize,
}
impl Shard {
pub fn new(device_id: DeviceId, offset: Vec<usize>, shape: Vec<usize>, rank: usize) -> Self {
Self {
device_id,
offset,
shape,
rank,
}
}
pub fn device_id(&self) -> DeviceId {
self.device_id
}
pub fn offset(&self) -> &[usize] {
&self.offset
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn rank(&self) -> usize {
self.rank
}
pub fn size(&self) -> usize {
self.shape.iter().product()
}
}
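/// A logical tensor described by its global shape together with the strategy
/// and device group used to split it into per-device shards.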
#[derive(Debug, Clone)]
pub struct DistributedTensor {
global_shape: Vec<usize>,
strategy: ShardingStrategy,
device_group: DeviceGroup,
shards: Vec<Shard>,
}
impl DistributedTensor {
pub fn new(
global_shape: Vec<usize>,
strategy: ShardingStrategy,
device_group: DeviceGroup,
) -> Self {
let shards = Self::create_shards(&global_shape, strategy, &device_group);
Self {
global_shape,
strategy,
device_group,
shards,
}
}
fn create_shards(
global_shape: &[usize],
strategy: ShardingStrategy,
device_group: &DeviceGroup,
) -> Vec<Shard> {
        let num_devices = device_group.size();
        let mut shards = Vec::new();
        // An empty device group would make the chunk-size divisions below
        // divide by zero, so bail out early with no shards.
        if num_devices == 0 {
            return shards;
        }
match strategy {
ShardingStrategy::Replicated => {
for (rank, &device_id) in device_group.devices().iter().enumerate() {
shards.push(Shard::new(
device_id,
vec![0; global_shape.len()],
global_shape.to_vec(),
rank,
));
}
}
ShardingStrategy::DataParallel | ShardingStrategy::DimSharded(0) => {
if global_shape.is_empty() {
return shards;
}
let dim0 = global_shape[0];
let chunk_size = (dim0 + num_devices - 1) / num_devices;
for (rank, &device_id) in device_group.devices().iter().enumerate() {
let start = rank * chunk_size;
let end = (start + chunk_size).min(dim0);
if start >= dim0 {
break;
}
let mut offset = vec![0; global_shape.len()];
offset[0] = start;
let mut shape = global_shape.to_vec();
shape[0] = end - start;
shards.push(Shard::new(device_id, offset, shape, rank));
}
}
            ShardingStrategy::ModelParallel => {
                // Simplification: model parallelism currently reuses the
                // row-sharded (DataParallel) layout.
                return Self::create_shards(
                    global_shape,
                    ShardingStrategy::DataParallel,
                    device_group,
                );
            }
ShardingStrategy::DimSharded(dim) => {
if dim >= global_shape.len() {
return shards;
}
let dim_size = global_shape[dim];
let chunk_size = (dim_size + num_devices - 1) / num_devices;
for (rank, &device_id) in device_group.devices().iter().enumerate() {
let start = rank * chunk_size;
let end = (start + chunk_size).min(dim_size);
if start >= dim_size {
break;
}
let mut offset = vec![0; global_shape.len()];
offset[dim] = start;
let mut shape = global_shape.to_vec();
shape[dim] = end - start;
shards.push(Shard::new(device_id, offset, shape, rank));
}
}
            // Pipeline and Hybrid have no dedicated layout yet; fall back to
            // full replication.
            ShardingStrategy::Pipeline | ShardingStrategy::Hybrid => {
                return Self::create_shards(
                    global_shape,
                    ShardingStrategy::Replicated,
                    device_group,
                );
            }
}
shards
}
pub fn global_shape(&self) -> &[usize] {
&self.global_shape
}
pub fn strategy(&self) -> ShardingStrategy {
self.strategy
}
pub fn device_group(&self) -> &DeviceGroup {
&self.device_group
}
pub fn shards(&self) -> &[Shard] {
&self.shards
}
pub fn shard_for_device(&self, device_id: &DeviceId) -> Option<&Shard> {
self.shards.iter().find(|s| &s.device_id == device_id)
}
    pub fn total_elements(&self) -> usize {
        match self.strategy {
            // Replicated layouts (including the Pipeline/Hybrid fallback)
            // store a full copy per device, so count the logical elements
            // once instead of summing the shards.
            ShardingStrategy::Replicated
            | ShardingStrategy::Pipeline
            | ShardingStrategy::Hybrid => self.global_shape.iter().product(),
            // Sharded layouts partition the tensor, so the shard sizes sum to
            // the logical element count.
            _ => self.shards.iter().map(|s| s.size()).sum(),
        }
    }
}
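/// Collective communication primitives that a backend may be asked to run.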
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CollectiveOp {
AllReduce(ReduceOp),
AllGather,
ReduceScatter(ReduceOp),
Broadcast { root: usize },
Scatter { root: usize },
Gather { root: usize },
AllToAll,
Barrier,
}
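/// Element-wise reduction applied by the reducing collectives.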
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
Sum,
Product,
Min,
Max,
Average,
}
impl fmt::Display for ReduceOp {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ReduceOp::Sum => write!(f, "Sum"),
ReduceOp::Product => write!(f, "Product"),
ReduceOp::Min => write!(f, "Min"),
ReduceOp::Max => write!(f, "Max"),
ReduceOp::Average => write!(f, "Average"),
}
}
}
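/// Communication backend used to execute collective operations.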
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommBackend {
NCCL,
Gloo,
MPI,
Custom,
}
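/// A single collective operation: what to run, over which devices, with which
/// backend, and whether it may complete asynchronously.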
#[derive(Debug, Clone)]
pub struct CommunicationDescriptor {
operation: CollectiveOp,
device_group: DeviceGroup,
backend: CommBackend,
async_op: bool,
}
impl CommunicationDescriptor {
pub fn new(operation: CollectiveOp, device_group: DeviceGroup, backend: CommBackend) -> Self {
Self {
operation,
device_group,
backend,
async_op: false,
}
}
pub fn with_async(mut self, async_op: bool) -> Self {
self.async_op = async_op;
self
}
pub fn operation(&self) -> CollectiveOp {
self.operation
}
pub fn device_group(&self) -> &DeviceGroup {
&self.device_group
}
pub fn backend(&self) -> CommBackend {
self.backend
}
pub fn is_async(&self) -> bool {
self.async_op
}
}
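/// Bookkeeping for a saved training checkpoint: identifier, training step,
/// participating devices, and free-form key/value metadata. The timestamp
/// currently defaults to 0 because no setter is provided.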
#[derive(Debug, Clone)]
pub struct CheckpointMetadata {
id: String,
step: u64,
devices: Vec<DeviceId>,
timestamp: u64,
metadata: Vec<(String, String)>,
}
impl CheckpointMetadata {
pub fn new(id: impl Into<String>, step: u64, devices: Vec<DeviceId>) -> Self {
Self {
id: id.into(),
step,
devices,
            timestamp: 0,
            metadata: Vec::new(),
}
}
pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.metadata.push((key.into(), value.into()));
}
pub fn id(&self) -> &str {
&self.id
}
pub fn step(&self) -> u64 {
self.step
}
pub fn devices(&self) -> &[DeviceId] {
&self.devices
}
pub fn timestamp(&self) -> u64 {
self.timestamp
}
pub fn metadata(&self) -> &[(String, String)] {
&self.metadata
}
}
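/// A dense rack/node/device enumeration of the cluster. Node indices are
/// local to their rack, so `node_devices(n)` collects node `n` from every
/// rack.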
#[derive(Debug, Clone)]
pub struct DeviceTopology {
devices: Vec<DeviceId>,
num_nodes: usize,
num_racks: usize,
devices_per_node: usize,
}
impl DeviceTopology {
pub fn new(num_racks: usize, num_nodes: usize, devices_per_node: usize) -> Self {
let mut devices = Vec::new();
for rack_id in 0..num_racks {
for node_id in 0..num_nodes {
for device_id in 0..devices_per_node {
devices.push(DeviceId::new(node_id, rack_id, device_id));
}
}
}
Self {
devices,
num_nodes,
num_racks,
devices_per_node,
}
}
pub fn devices(&self) -> &[DeviceId] {
&self.devices
}
pub fn node_devices(&self, node_id: usize) -> Vec<DeviceId> {
self.devices
.iter()
.filter(|d| d.node_id() == node_id)
.copied()
.collect()
}
pub fn rack_devices(&self, rack_id: usize) -> Vec<DeviceId> {
self.devices
.iter()
.filter(|d| d.rack_id() == rack_id)
.copied()
.collect()
}
pub fn total_devices(&self) -> usize {
self.devices.len()
}
pub fn num_nodes(&self) -> usize {
self.num_nodes
}
pub fn num_racks(&self) -> usize {
self.num_racks
}
pub fn devices_per_node(&self) -> usize {
self.devices_per_node
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_device_id() {
let device = DeviceId::new(0, 1, 2);
assert_eq!(device.node_id(), 0);
assert_eq!(device.rack_id(), 1);
assert_eq!(device.local_device_id(), 2);
        assert_eq!(device.global_id(), 1002);
    }
#[test]
fn test_simple_device_id() {
let device = DeviceId::simple(5);
assert_eq!(device.local_device_id(), 5);
assert_eq!(device.node_id(), 0);
assert_eq!(device.rack_id(), 0);
}
#[test]
fn test_device_group() {
let group = DeviceGroup::new(vec![0, 1, 2, 3]);
assert_eq!(group.size(), 4);
assert!(group.contains(&DeviceId::simple(0)));
assert_eq!(group.rank(&DeviceId::simple(2)), Some(2));
}
#[test]
fn test_device_group_with_name() {
let group = DeviceGroup::new(vec![0, 1]).with_name("test_group");
assert_eq!(group.name(), Some("test_group"));
}
#[test]
fn test_sharding_strategy_display() {
assert_eq!(format!("{}", ShardingStrategy::Replicated), "Replicated");
assert_eq!(
format!("{}", ShardingStrategy::DataParallel),
"DataParallel"
);
assert_eq!(
format!("{}", ShardingStrategy::DimSharded(1)),
"DimSharded(1)"
);
}
#[test]
fn test_shard() {
let device = DeviceId::simple(0);
let shard = Shard::new(device, vec![0, 0], vec![10, 20], 0);
assert_eq!(shard.device_id(), device);
assert_eq!(shard.offset(), &[0, 0]);
assert_eq!(shard.shape(), &[10, 20]);
assert_eq!(shard.rank(), 0);
assert_eq!(shard.size(), 200);
}
#[test]
fn test_distributed_tensor_replicated() {
let group = DeviceGroup::new(vec![0, 1, 2, 3]);
let tensor = DistributedTensor::new(vec![100, 50], ShardingStrategy::Replicated, group);
assert_eq!(tensor.global_shape(), &[100, 50]);
assert_eq!(tensor.shards().len(), 4);
assert_eq!(tensor.strategy(), ShardingStrategy::Replicated);
for shard in tensor.shards() {
assert_eq!(shard.shape(), &[100, 50]);
}
}
#[test]
fn test_distributed_tensor_data_parallel() {
let group = DeviceGroup::new(vec![0, 1, 2, 3]);
let tensor = DistributedTensor::new(vec![100, 50], ShardingStrategy::DataParallel, group);
assert_eq!(tensor.shards().len(), 4);
for shard in tensor.shards() {
assert_eq!(shard.shape()[0], 25);
assert_eq!(shard.shape()[1], 50);
}
}
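    // Extra coverage sketch: with 3 devices and 10 rows, the ceiling-division
    // chunking gives the last rank the 2-row remainder.
    #[test]
    fn test_distributed_tensor_uneven_split() {
        let group = DeviceGroup::new(vec![0, 1, 2]);
        let tensor = DistributedTensor::new(vec![10, 4], ShardingStrategy::DataParallel, group);
        assert_eq!(tensor.shards().len(), 3);
        assert_eq!(tensor.shards()[0].shape(), &[4, 4]);
        assert_eq!(tensor.shards()[1].shape(), &[4, 4]);
        assert_eq!(tensor.shards()[2].shape(), &[2, 4]);
        assert_eq!(tensor.total_elements(), 40);
    }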
#[test]
fn test_distributed_tensor_dim_sharded() {
let group = DeviceGroup::new(vec![0, 1]);
let tensor =
DistributedTensor::new(vec![10, 20, 30], ShardingStrategy::DimSharded(1), group);
assert_eq!(tensor.shards().len(), 2);
assert_eq!(tensor.shards()[0].shape(), &[10, 10, 30]);
assert_eq!(tensor.shards()[1].shape(), &[10, 10, 30]);
}
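    // Extra coverage sketch: sharding along a dimension the tensor does not
    // have produces no shards rather than panicking.
    #[test]
    fn test_dim_sharded_out_of_range() {
        let group = DeviceGroup::new(vec![0, 1]);
        let tensor = DistributedTensor::new(vec![10, 20], ShardingStrategy::DimSharded(5), group);
        assert!(tensor.shards().is_empty());
    }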
#[test]
fn test_shard_for_device() {
let group = DeviceGroup::new(vec![0, 1]);
let tensor = DistributedTensor::new(vec![10, 20], ShardingStrategy::DataParallel, group);
let device = DeviceId::simple(0);
let shard = tensor.shard_for_device(&device);
assert!(shard.is_some());
assert_eq!(
shard.expect("shard_for_device should succeed").device_id(),
device
);
}
#[test]
fn test_collective_operations() {
let _all_reduce = CollectiveOp::AllReduce(ReduceOp::Sum);
let _all_gather = CollectiveOp::AllGather;
let _reduce_scatter = CollectiveOp::ReduceScatter(ReduceOp::Average);
let _broadcast = CollectiveOp::Broadcast { root: 0 };
let _scatter = CollectiveOp::Scatter { root: 0 };
let _gather = CollectiveOp::Gather { root: 0 };
let _all_to_all = CollectiveOp::AllToAll;
let _barrier = CollectiveOp::Barrier;
}
#[test]
fn test_reduce_op_display() {
assert_eq!(format!("{}", ReduceOp::Sum), "Sum");
assert_eq!(format!("{}", ReduceOp::Product), "Product");
assert_eq!(format!("{}", ReduceOp::Min), "Min");
assert_eq!(format!("{}", ReduceOp::Max), "Max");
assert_eq!(format!("{}", ReduceOp::Average), "Average");
}
#[test]
fn test_comm_backend() {
let _nccl = CommBackend::NCCL;
let _gloo = CommBackend::Gloo;
let _mpi = CommBackend::MPI;
let _custom = CommBackend::Custom;
}
#[test]
fn test_communication_descriptor() {
let group = DeviceGroup::new(vec![0, 1, 2, 3]);
let comm_desc = CommunicationDescriptor::new(
CollectiveOp::AllReduce(ReduceOp::Sum),
group.clone(),
CommBackend::NCCL,
)
.with_async(true);
assert_eq!(
comm_desc.operation(),
CollectiveOp::AllReduce(ReduceOp::Sum)
);
assert_eq!(comm_desc.backend(), CommBackend::NCCL);
assert!(comm_desc.is_async());
}
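    // Extra coverage sketch: a descriptor built without `with_async` defaults
    // to synchronous execution.
    #[test]
    fn test_communication_descriptor_defaults() {
        let group = DeviceGroup::new(vec![0, 1]);
        let comm_desc =
            CommunicationDescriptor::new(CollectiveOp::Barrier, group, CommBackend::Gloo);
        assert!(!comm_desc.is_async());
        assert_eq!(comm_desc.device_group().size(), 2);
    }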
#[test]
fn test_checkpoint_metadata() {
let devices = vec![DeviceId::simple(0), DeviceId::simple(1)];
let mut checkpoint = CheckpointMetadata::new("ckpt_001", 1000, devices);
checkpoint.add_metadata("model", "resnet50");
checkpoint.add_metadata("optimizer", "adam");
assert_eq!(checkpoint.id(), "ckpt_001");
assert_eq!(checkpoint.step(), 1000);
assert_eq!(checkpoint.devices().len(), 2);
assert_eq!(checkpoint.metadata().len(), 2);
}
#[test]
fn test_device_topology() {
        let topology = DeviceTopology::new(2, 3, 4);
        assert_eq!(topology.total_devices(), 24);
        assert_eq!(topology.num_racks(), 2);
assert_eq!(topology.num_nodes(), 3);
assert_eq!(topology.devices_per_node(), 4);
let node0_devices = topology.node_devices(0);
assert_eq!(node0_devices.len(), 8);
let rack0_devices = topology.rack_devices(0);
        assert_eq!(rack0_devices.len(), 12);
    }
#[test]
fn test_total_elements() {
let group = DeviceGroup::new(vec![0, 1, 2, 3]);
let replicated =
DistributedTensor::new(vec![100, 50], ShardingStrategy::Replicated, group.clone());
assert_eq!(replicated.total_elements(), 5000);
let sharded = DistributedTensor::new(vec![100, 50], ShardingStrategy::DataParallel, group);
        assert_eq!(sharded.total_elements(), 5000);
    }
#[test]
fn test_from_devices() {
let devices = vec![DeviceId::new(0, 0, 1), DeviceId::new(0, 0, 2)];
let group = DeviceGroup::from_devices(devices);
assert_eq!(group.size(), 2);
}
#[test]
fn test_device_not_in_group() {
let group = DeviceGroup::new(vec![0, 1, 2]);
assert!(!group.contains(&DeviceId::simple(5)));
assert_eq!(group.rank(&DeviceId::simple(5)), None);
}
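    // Extra coverage sketch: Pipeline (like Hybrid) currently falls back to
    // full replication when shards are created.
    #[test]
    fn test_pipeline_falls_back_to_replication() {
        let group = DeviceGroup::new(vec![0, 1]);
        let tensor = DistributedTensor::new(vec![8, 8], ShardingStrategy::Pipeline, group);
        assert_eq!(tensor.shards().len(), 2);
        for shard in tensor.shards() {
            assert_eq!(shard.shape(), &[8, 8]);
        }
    }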
}