use rayon::prelude::*;
/// Identifier of a device.
///
/// Simulated devices use their linear position in the mesh's device list as the id.
/// The type is a plain `usize` newtype, so it is cheap to copy and totally ordered.
/// This lets ids be sorted or used as `BTreeMap` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct DeviceId(pub usize);
/// Static description of a single (simulated) accelerator.
#[derive(Clone, Debug)]
pub struct DeviceInfo {
    /// Identifier of this device.
    pub id: DeviceId,
    /// Device memory capacity, in bytes.
    pub memory_bytes: usize,
    /// Number of compute units reported for the device.
    pub compute_units: usize,
    /// Human-readable device name.
    pub name: String,
}
impl DeviceInfo {
    /// Memory reported by every simulated device: 24 GiB.
    #[cfg(target_pointer_width = "64")]
    const SIMULATED_MEMORY_BYTES: usize = 24 * 1024 * 1024 * 1024;
    /// 24 GiB does not fit in a `usize` narrower than 64 bits.
    /// On such targets the literal product is a compile-time overflow error,
    /// so they report the largest representable size instead.
    #[cfg(not(target_pointer_width = "64"))]
    const SIMULATED_MEMORY_BYTES: usize = usize::MAX;
    /// Compute-unit count reported by every simulated device.
    const SIMULATED_COMPUTE_UNITS: usize = 108;

    /// Builds the simulated device whose id is `DeviceId(linear_id)`.
    ///
    /// All simulated devices share the same memory size and compute-unit count.
    /// They differ only in their id and their name (`"SimGPU-<linear_id>"`).
    fn simulated(linear_id: usize) -> Self {
        Self {
            id: DeviceId(linear_id),
            memory_bytes: Self::SIMULATED_MEMORY_BYTES,
            compute_units: Self::SIMULATED_COMPUTE_UNITS,
            name: format!("SimGPU-{linear_id}"),
        }
    }
}
/// A two-dimensional grid of devices addressed by tensor-parallel (TP) rank and
/// pipeline-parallel (PP) rank.
///
/// Devices are stored in one flat vector in which the TP rank varies fastest.
/// The device for `(tp_rank, pp_rank)` sits at index `tp_rank + pp_rank * tp_size`.
pub struct DeviceMesh {
    /// Number of tensor-parallel ranks in each pipeline stage.
    tp_size: usize,
    /// Number of pipeline stages.
    pp_size: usize,
    /// Flat device list, laid out as described on the type.
    devices: Vec<DeviceInfo>,
}
impl DeviceMesh {
    /// Creates a purely tensor-parallel mesh: `n` TP ranks in a single pipeline stage.
    pub fn tensor_parallel(n: usize) -> Self {
        Self::new(n, 1)
    }

    /// Creates a mesh of `tp_size * pp_size` simulated devices.
    ///
    /// The device stored at flat index `k` is `DeviceInfo::simulated(k)`.
    pub fn new(tp_size: usize, pp_size: usize) -> Self {
        let world = tp_size * pp_size;
        let mut devices = Vec::with_capacity(world);
        for linear_id in 0..world {
            devices.push(DeviceInfo::simulated(linear_id));
        }
        Self {
            tp_size,
            pp_size,
            devices,
        }
    }

    /// Total number of devices in the mesh.
    pub fn size(&self) -> usize {
        self.devices.len()
    }

    /// Returns the device at `(tp_rank, pp_rank)`, or `None` if either rank is out of range.
    pub fn get(&self, tp_rank: usize, pp_rank: usize) -> Option<&DeviceInfo> {
        if tp_rank < self.tp_size && pp_rank < self.pp_size {
            // The TP rank varies fastest in the flat device vector.
            self.devices.get(pp_rank * self.tp_size + tp_rank)
        } else {
            None
        }
    }

    /// Returns every device in pipeline stage `pp_rank`, ordered by TP rank.
    ///
    /// The result is empty when `pp_rank` is out of range.
    pub fn tp_group(&self, pp_rank: usize) -> Vec<&DeviceInfo> {
        let mut group = Vec::new();
        if pp_rank < self.pp_size {
            group.extend((0..self.tp_size).filter_map(|tp_rank| self.get(tp_rank, pp_rank)));
        }
        group
    }

    /// Returns every device with TP rank `tp_rank`, ordered by pipeline stage.
    ///
    /// The result is empty when `tp_rank` is out of range.
    pub fn pp_group(&self, tp_rank: usize) -> Vec<&DeviceInfo> {
        let mut group = Vec::new();
        if tp_rank < self.tp_size {
            group.extend((0..self.pp_size).filter_map(|pp_rank| self.get(tp_rank, pp_rank)));
        }
        group
    }
}
/// Output of a collective operation together with some bookkeeping about the call.
#[derive(Clone, Debug)]
pub struct CollectiveResult {
    /// The resulting buffer.
    pub data: Vec<f32>,
    /// Number of input shards (simulated devices) that took part.
    pub participating_devices: usize,
    /// Name of the operation that produced this result.
    pub op_name: &'static str,
}
/// CPU-side stand-in for NCCL-style collective operations.
///
/// Each element of a `shards` slice plays the role of one device's buffer.
/// Every operation returns freshly allocated output.
/// Element-wise reductions are parallelised over element indices with rayon.
pub struct NcclCollectives;

impl NcclCollectives {
    /// Returns the length shared by every shard, or 0 when there are no shards.
    ///
    /// # Panics
    ///
    /// Panics if the shards do not all have the same length.
    /// Without this check, a shorter shard would hit an opaque index-out-of-bounds
    /// panic mid-reduction. A longer one would be silently truncated to the first
    /// shard's length.
    fn common_len(shards: &[Vec<f32>], op_name: &str) -> usize {
        let len = shards.first().map_or(0, Vec::len);
        assert!(
            shards.iter().all(|s| s.len() == len),
            "{}: all shards must have the same length",
            op_name
        );
        len
    }

    /// Element-wise sum across all shards.
    ///
    /// Output element `i` is the sum of `shards[k][i]` over every shard `k`.
    /// With no shards the result is empty.
    ///
    /// # Panics
    ///
    /// Panics if the shards have different lengths.
    pub fn all_reduce_sum(shards: &[Vec<f32>]) -> CollectiveResult {
        let n = Self::common_len(shards, "all_reduce_sum");
        let data: Vec<f32> = (0..n)
            .into_par_iter()
            .map(|i| shards.iter().map(|s| s[i]).sum::<f32>())
            .collect();
        CollectiveResult {
            data,
            participating_devices: shards.len(),
            op_name: "all_reduce_sum",
        }
    }

    /// Element-wise maximum across all shards.
    ///
    /// Uses `f32::max`, so a NaN operand is ignored in favour of the other one.
    /// An element that is NaN in every shard yields `f32::NEG_INFINITY`, the fold's
    /// starting value. With no shards the result is empty.
    ///
    /// # Panics
    ///
    /// Panics if the shards have different lengths.
    pub fn all_reduce_max(shards: &[Vec<f32>]) -> CollectiveResult {
        let n = Self::common_len(shards, "all_reduce_max");
        let data: Vec<f32> = (0..n)
            .into_par_iter()
            .map(|i| {
                shards
                    .iter()
                    .map(|s| s[i])
                    .fold(f32::NEG_INFINITY, f32::max)
            })
            .collect();
        CollectiveResult {
            data,
            participating_devices: shards.len(),
            op_name: "all_reduce_max",
        }
    }

    /// Concatenates all shards in slice order.
    ///
    /// Unlike the reductions, shards may have different lengths here.
    pub fn all_gather(shards: &[Vec<f32>]) -> CollectiveResult {
        let data: Vec<f32> = shards.iter().flat_map(|s| s.iter().copied()).collect();
        CollectiveResult {
            data,
            participating_devices: shards.len(),
            op_name: "all_gather",
        }
    }

    /// Splits `data` into `world_size` contiguous chunks, one per rank.
    ///
    /// NOTE: despite the name, no reduction is performed. `data` is split as-is,
    /// so callers must pass an already-reduced buffer.
    ///
    /// Chunk sizes differ by at most one element.
    /// The first `data.len() % world_size` ranks each receive the extra element.
    /// Returns an empty vector when `world_size == 0`.
    pub fn reduce_scatter(data: &[f32], world_size: usize) -> Vec<Vec<f32>> {
        if world_size == 0 {
            return Vec::new();
        }
        let base = data.len() / world_size;
        let remainder = data.len() % world_size;
        (0..world_size)
            .map(|rank| {
                let start = rank * base + rank.min(remainder);
                let end = start + base + if rank < remainder { 1 } else { 0 };
                data[start..end.min(data.len())].to_vec()
            })
            .collect()
    }

    /// Returns `world_size` independent copies of `data`, one per rank.
    pub fn broadcast(data: &[f32], world_size: usize) -> Vec<Vec<f32>> {
        (0..world_size).map(|_| data.to_vec()).collect()
    }
}
/// Splits a row-major `rows × cols` matrix into `world_size` column blocks, one per rank.
///
/// Columns are handed out contiguously and as evenly as possible.
/// Every rank gets `cols / world_size` columns, and the first `cols % world_size`
/// ranks get one extra. Each shard is itself row-major with `rows` rows.
/// Returns an empty vector when `world_size == 0`.
///
/// # Panics
///
/// Panics if `weights` holds fewer than `rows * cols` elements.
pub fn partition_weights_column(
    weights: &[f32],
    rows: usize,
    cols: usize,
    world_size: usize,
) -> Vec<Vec<f32>> {
    if world_size == 0 {
        return Vec::new();
    }
    let per_rank = cols / world_size;
    let extra = cols % world_size;
    let mut shards = Vec::with_capacity(world_size);
    // Running column offset; equals `rank * per_rank + min(rank, extra)` for each rank.
    let mut first_col = 0usize;
    for rank in 0..world_size {
        let width = per_rank + usize::from(rank < extra);
        let mut shard = Vec::with_capacity(rows * width);
        shard.extend((0..rows).flat_map(move |row| {
            let lo = row * cols + first_col;
            weights[lo..lo + width].iter().copied()
        }));
        shards.push(shard);
        first_col += width;
    }
    shards
}
/// Splits a row-major `rows × cols` matrix into `world_size` row blocks, one per rank.
///
/// Rows are handed out contiguously and as evenly as possible.
/// Every rank gets `rows / world_size` rows, and the first `rows % world_size`
/// ranks get one extra. Each shard is a contiguous copy of its rows.
/// Returns an empty vector when `world_size == 0`.
///
/// # Panics
///
/// Panics if `weights` holds fewer than `rows * cols` elements.
pub fn partition_weights_row(
    weights: &[f32],
    rows: usize,
    cols: usize,
    world_size: usize,
) -> Vec<Vec<f32>> {
    if world_size == 0 {
        return Vec::new();
    }
    let per_rank = rows / world_size;
    let extra = rows % world_size;
    // Running row offset; equals `rank * per_rank + min(rank, extra)` for each rank.
    let mut next_row = 0usize;
    (0..world_size)
        .map(|rank| {
            let height = per_rank + usize::from(rank < extra);
            let first = next_row;
            next_row += height;
            weights[first * cols..next_row * cols].to_vec()
        })
        .collect()
}
/// Reassembles a row-major matrix from column shards; this is the inverse of
/// `partition_weights_column`.
///
/// Each shard is a row-major block with `rows` rows and `shard.len() / rows` columns.
/// The output has `rows` rows, with the shards' columns placed side by side in slice order.
/// Returns an empty vector when `shards` is empty or `rows == 0`.
///
/// # Panics
///
/// Panics if any shard's length is not a multiple of `rows`. Such a shard cannot be
/// a `rows`-row block, and its trailing elements would otherwise be silently dropped.
pub fn merge_column_shards(shards: &[Vec<f32>], rows: usize) -> Vec<f32> {
    if shards.is_empty() || rows == 0 {
        return Vec::new();
    }
    assert!(
        shards.iter().all(|s| s.len() % rows == 0),
        "merge_column_shards: every shard length must be a multiple of rows ({})",
        rows
    );
    let total_len: usize = shards.iter().map(Vec::len).sum();
    let mut result = Vec::with_capacity(total_len);
    // Emit output row by row: for each row, append that row's slice from every shard in order.
    for row in 0..rows {
        for shard in shards {
            let shard_cols = shard.len() / rows;
            let start = row * shard_cols;
            result.extend_from_slice(&shard[start..start + shard_cols]);
        }
    }
    result
}