#[derive(Debug, Clone)]
pub struct OptimizerShard {
pub rank: usize,
pub world_size: usize,
pub param_start: usize,
pub param_end: usize,
pub total_params: usize,
}
impl OptimizerShard {
pub fn for_rank(rank: usize, world_size: usize, total_params: usize) -> Self {
let shard_size = total_params / world_size;
let remainder = total_params % world_size;
let param_start = if rank < remainder {
rank * (shard_size + 1)
} else {
remainder * (shard_size + 1) + (rank - remainder) * shard_size
};
let param_end =
if rank < remainder { param_start + shard_size + 1 } else { param_start + shard_size };
Self { rank, world_size, param_start, param_end, total_params }
}
pub fn shard_size(&self) -> usize {
self.param_end - self.param_start
}
pub fn owns_param(&self, param_idx: usize) -> bool {
param_idx >= self.param_start && param_idx < self.param_end
}
pub fn memory_savings(&self) -> f64 {
1.0 - (1.0 / self.world_size as f64)
}
pub fn shard_memory_bytes(&self) -> usize {
self.shard_size() * 2 * std::mem::size_of::<f32>()
}
pub fn full_memory_bytes(&self) -> usize {
self.total_params * 2 * std::mem::size_of::<f32>()
}
}
#[derive(Debug, Clone)]
pub struct ZeroShardMap {
pub block_owners: Vec<usize>,
pub lm_head_owner: usize,
pub final_norm_owner: usize,
pub embedding_owner: usize,
pub world_size: usize,
}
impl ZeroShardMap {
pub fn round_robin(num_blocks: usize, world_size: usize) -> Self {
let block_owners: Vec<usize> = (0..num_blocks).map(|i| i % world_size).collect();
Self { block_owners, lm_head_owner: 0, final_norm_owner: 0, embedding_owner: 0, world_size }
}
pub fn contiguous(num_blocks: usize, world_size: usize) -> Self {
let blocks_per_worker = num_blocks / world_size;
let remainder = num_blocks % world_size;
let mut block_owners = Vec::with_capacity(num_blocks);
for rank in 0..world_size {
let count = blocks_per_worker + usize::from(rank < remainder);
for _ in 0..count {
block_owners.push(rank);
}
}
Self { block_owners, lm_head_owner: 0, final_norm_owner: 0, embedding_owner: 0, world_size }
}
pub fn block_owner(&self, block_idx: usize) -> usize {
self.block_owners[block_idx]
}
pub fn rank_owns_block(&self, rank: usize, block_idx: usize) -> bool {
self.block_owners[block_idx] == rank
}
pub fn blocks_for_rank(&self, rank: usize) -> Vec<usize> {
self.block_owners
.iter()
.enumerate()
.filter(|(_, &owner)| owner == rank)
.map(|(i, _)| i)
.collect()
}
pub fn num_blocks_for_rank(&self, rank: usize) -> usize {
self.block_owners.iter().filter(|&&owner| owner == rank).count()
}
pub fn memory_fraction_for_rank(&self, rank: usize) -> f64 {
let owned = self.num_blocks_for_rank(rank) as f64;
let total = self.block_owners.len() as f64;
owned / total
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_optimizer_shard_basic() {
let shard = OptimizerShard::for_rank(0, 4, 100);
assert_eq!(shard.shard_size(), 25);
assert_eq!(shard.param_start, 0);
assert_eq!(shard.param_end, 25);
assert!(shard.owns_param(0));
assert!(shard.owns_param(24));
assert!(!shard.owns_param(25));
}
#[test]
fn test_optimizer_shard_remainder() {
let s0 = OptimizerShard::for_rank(0, 3, 10);
let s1 = OptimizerShard::for_rank(1, 3, 10);
let s2 = OptimizerShard::for_rank(2, 3, 10);
assert_eq!(s0.shard_size(), 4); assert_eq!(s1.shard_size(), 3);
assert_eq!(s2.shard_size(), 3);
assert_eq!(s0.param_start, 0);
assert_eq!(s0.param_end, 4);
assert_eq!(s1.param_start, 4);
assert_eq!(s1.param_end, 7);
assert_eq!(s2.param_start, 7);
assert_eq!(s2.param_end, 10);
}
#[test]
fn test_optimizer_shard_completeness() {
let total = 1_000_003; let world_size = 7;
let mut covered = vec![false; total];
for rank in 0..world_size {
let shard = OptimizerShard::for_rank(rank, world_size, total);
for i in shard.param_start..shard.param_end {
assert!(!covered[i], "param {i} covered by multiple shards");
covered[i] = true;
}
}
assert!(covered.iter().all(|&c| c), "not all params covered");
}
#[test]
fn test_optimizer_shard_memory_savings() {
let shard = OptimizerShard::for_rank(0, 4, 1_000_000);
assert!((shard.memory_savings() - 0.75).abs() < 1e-10);
assert_eq!(shard.full_memory_bytes(), 8_000_000);
assert_eq!(shard.shard_memory_bytes(), 2_000_000);
}
#[test]
fn test_zero_shard_map_round_robin() {
let map = ZeroShardMap::round_robin(24, 4);
assert_eq!(map.block_owner(0), 0);
assert_eq!(map.block_owner(1), 1);
assert_eq!(map.block_owner(2), 2);
assert_eq!(map.block_owner(3), 3);
assert_eq!(map.block_owner(4), 0);
assert_eq!(map.num_blocks_for_rank(0), 6);
assert_eq!(map.num_blocks_for_rank(1), 6);
let blocks = map.blocks_for_rank(0);
assert_eq!(blocks, vec![0, 4, 8, 12, 16, 20]);
}
#[test]
fn test_zero_shard_map_contiguous() {
let map = ZeroShardMap::contiguous(24, 4);
assert_eq!(map.blocks_for_rank(0), vec![0, 1, 2, 3, 4, 5]);
assert_eq!(map.blocks_for_rank(1), vec![6, 7, 8, 9, 10, 11]);
assert_eq!(map.blocks_for_rank(2), vec![12, 13, 14, 15, 16, 17]);
assert_eq!(map.blocks_for_rank(3), vec![18, 19, 20, 21, 22, 23]);
}
#[test]
fn test_zero_shard_map_contiguous_uneven() {
let map = ZeroShardMap::contiguous(10, 3);
assert_eq!(map.num_blocks_for_rank(0), 4);
assert_eq!(map.num_blocks_for_rank(1), 3);
assert_eq!(map.num_blocks_for_rank(2), 3);
let total: usize = (0..3).map(|r| map.num_blocks_for_rank(r)).sum();
assert_eq!(total, 10);
}
#[test]
fn test_zero_shard_map_memory_fraction() {
let map = ZeroShardMap::round_robin(24, 4);
let frac = map.memory_fraction_for_rank(0);
assert!((frac - 0.25).abs() < 1e-10);
}
#[test]
fn test_zero_shard_map_rank_owns_block() {
let map = ZeroShardMap::contiguous(12, 3);
assert!(map.rank_owns_block(0, 0));
assert!(map.rank_owns_block(0, 3));
assert!(!map.rank_owns_block(0, 4));
assert!(map.rank_owns_block(1, 4));
}
#[test]
fn test_zero_shard_350m() {
let map = ZeroShardMap::contiguous(24, 4);
for rank in 0..4 {
assert_eq!(map.num_blocks_for_rank(rank), 6);
let frac = map.memory_fraction_for_rank(rank);
assert!((frac - 0.25).abs() < 1e-10);
}
}
}