use oxicuda_driver::{CudaError, CudaResult};
use std::collections::HashMap;
use std::ops::Range;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
/// Compact, copyable identifier for a node in the cluster.
///
/// Ordering and hashing follow the underlying `u32`, so ids can be used as
/// map keys or sorted directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct NodeId(pub u32);

impl NodeId {
    /// Wraps a raw numeric id.
    pub fn new(id: u32) -> Self {
        NodeId(id)
    }

    /// Returns the raw numeric id.
    pub fn value(&self) -> u32 {
        let NodeId(raw) = self;
        *raw
    }
}

impl std::fmt::Display for NodeId {
    /// Renders as `Node(<id>)`, e.g. `Node(42)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Node({})", self.0))
    }
}
/// Static description of one machine participating in the cluster.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Cluster-unique identifier of the node.
    pub node_id: NodeId,
    /// Human-readable host name.
    pub hostname: String,
    /// Address used to reach the node.
    pub ip_addr: String,
    /// Port the node listens on.
    pub port: u16,
    /// Number of GPUs available on the node.
    pub gpu_count: u32,
    /// Rank assigned to the node.
    pub rank: u32,
}

impl NodeInfo {
    /// Builds a `NodeInfo`, taking borrowed strings and storing owned copies.
    pub fn new(
        node_id: NodeId,
        hostname: &str,
        ip_addr: &str,
        port: u16,
        gpu_count: u32,
        rank: u32,
    ) -> Self {
        NodeInfo {
            node_id,
            hostname: hostname.to_owned(),
            ip_addr: ip_addr.to_owned(),
            port,
            gpu_count,
            rank,
        }
    }
}
/// Transport used for inter-process communication.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DistributedBackend {
/// Communication over TCP sockets (the default).
#[default]
Tcp,
/// Communication through shared memory on a single host.
SharedMemory,
}
/// How a distributed job discovers its peers at startup.
///
/// NOTE(review): no code in this file consumes `InitMethod`; presumably it is
/// matched on by callers elsewhere in the crate — confirm before changing.
#[derive(Debug, Clone)]
pub enum InitMethod {
/// Rendezvous via a master process at a well-known address and port.
TcpRendezvous {
master_addr: String,
master_port: u16,
},
/// Read connection details from environment variables
/// (see `DistributedRuntime::from_env`).
EnvVars,
/// Rendezvous through files in a shared directory (see `FileStore`).
FileStore(PathBuf),
}
/// Per-process configuration for the distributed runtime.
///
/// Invariants are enforced by [`DistributedConfig::validate`].
#[derive(Debug, Clone)]
pub struct DistributedConfig {
/// Total number of processes in the job; must be non-zero.
pub world_size: u32,
/// Rank of this process within its own node.
pub local_rank: u32,
/// Rank of this process across the whole job; must be < `world_size`.
pub global_rank: u32,
/// Address of the rendezvous master; must be non-empty.
pub master_addr: String,
/// Port of the rendezvous master; must be non-zero.
pub master_port: u16,
/// Transport backend to use.
pub backend: DistributedBackend,
}
impl DistributedConfig {
    /// Checks that the configuration is internally consistent.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when the world size is zero, the
    /// global rank is not below the world size, the master address is empty,
    /// or the master port is zero.
    pub fn validate(&self) -> CudaResult<()> {
        let valid = self.world_size > 0
            && self.global_rank < self.world_size
            && !self.master_addr.is_empty()
            && self.master_port != 0;
        if valid {
            Ok(())
        } else {
            Err(CudaError::InvalidValue)
        }
    }
}
/// A subset of global ranks that communicate as a unit.
#[derive(Debug, Clone)]
pub struct ProcessGroup {
/// Caller-chosen identifier for the group.
pub group_id: u32,
/// Global ranks belonging to the group; position defines the group-local rank.
pub ranks: Vec<u32>,
/// Cached member count; always equals `ranks.len()`.
pub size: u32,
}
impl ProcessGroup {
    /// Builds a process group from a list of global ranks.
    ///
    /// # Errors
    /// Returns `CudaError::InvalidValue` when `ranks` is empty or contains
    /// duplicates: a rank may belong to a group at most once, and duplicates
    /// would make `size` overcount membership and `local_rank` ambiguous.
    pub fn new(group_id: u32, ranks: Vec<u32>) -> CudaResult<Self> {
        if ranks.is_empty() {
            return Err(CudaError::InvalidValue);
        }
        // Reject duplicate ranks; `insert` returns false on a repeat.
        let mut seen = std::collections::HashSet::with_capacity(ranks.len());
        if !ranks.iter().all(|&r| seen.insert(r)) {
            return Err(CudaError::InvalidValue);
        }
        let size = ranks.len() as u32;
        Ok(Self {
            group_id,
            ranks,
            size,
        })
    }

    /// Returns true when `rank` is a member of this group.
    pub fn contains_rank(&self, rank: u32) -> bool {
        self.ranks.contains(&rank)
    }

    /// Position of `rank` inside the group (its group-local rank), if present.
    pub fn local_rank(&self, rank: u32) -> Option<usize> {
        self.ranks.iter().position(|&r| r == rank)
    }
}
/// Lifecycle state of a `DistributedRuntime`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistributedStatus {
/// Start-up in progress. NOTE(review): not constructed anywhere in this
/// file; presumably set by callers elsewhere — confirm.
Initializing,
/// Runtime is idle and usable.
Ready,
/// A barrier is currently in progress.
Synchronizing,
/// An error occurred; carries a human-readable description.
Error(String),
/// `shutdown` was called; further barriers fail.
Shutdown,
}
/// Process-local handle for a distributed job.
pub struct DistributedRuntime {
/// Validated configuration captured at `init` time.
config: DistributedConfig,
/// Current lifecycle state, behind a lock so clones of the Arc stay in sync.
status: Arc<Mutex<DistributedStatus>>,
/// Number of barriers completed so far.
barrier_epoch: Arc<Mutex<u64>>,
}
impl DistributedRuntime {
    /// Validates `config` and constructs a runtime in the `Ready` state.
    ///
    /// # Errors
    /// Propagates `CudaError::InvalidValue` from `DistributedConfig::validate`.
    pub fn init(config: DistributedConfig) -> CudaResult<Self> {
        config.validate()?;
        Ok(Self {
            config,
            status: Arc::new(Mutex::new(DistributedStatus::Ready)),
            barrier_epoch: Arc::new(Mutex::new(0)),
        })
    }

    /// Reads the required environment variable `name` and parses it into `T`.
    /// A missing variable or a parse failure both map to `InvalidValue`.
    fn required_env<T: std::str::FromStr>(name: &str) -> CudaResult<T> {
        std::env::var(name)
            .map_err(|_| CudaError::InvalidValue)?
            .parse()
            .map_err(|_| CudaError::InvalidValue)
    }

    /// Builds a runtime from the conventional launcher environment variables:
    /// `MASTER_ADDR`, `MASTER_PORT`, `RANK`, `WORLD_SIZE`, and optionally
    /// `LOCAL_RANK`.
    ///
    /// # Errors
    /// `InvalidValue` when a required variable is missing or unparseable, or
    /// when the resulting configuration fails validation.
    pub fn from_env() -> CudaResult<Self> {
        let master_addr = std::env::var("MASTER_ADDR").map_err(|_| CudaError::InvalidValue)?;
        let master_port: u16 = Self::required_env("MASTER_PORT")?;
        let rank: u32 = Self::required_env("RANK")?;
        let world_size: u32 = Self::required_env("WORLD_SIZE")?;
        // LOCAL_RANK is optional; when absent, fall back to the global rank.
        let local_rank: u32 = std::env::var("LOCAL_RANK")
            .unwrap_or_else(|_| rank.to_string())
            .parse()
            .map_err(|_| CudaError::InvalidValue)?;
        Self::init(DistributedConfig {
            world_size,
            local_rank,
            global_rank: rank,
            master_addr,
            master_port,
            backend: DistributedBackend::Tcp,
        })
    }

    /// Total number of participating processes.
    pub fn world_size(&self) -> u32 {
        self.config.world_size
    }

    /// This process's global rank.
    pub fn rank(&self) -> u32 {
        self.config.global_rank
    }

    /// This process's rank within its node.
    pub fn local_rank(&self) -> u32 {
        self.config.local_rank
    }

    /// True for global rank 0.
    pub fn is_master(&self) -> bool {
        self.config.global_rank == 0
    }

    /// Bumps the barrier epoch, cycling status Ready -> Synchronizing -> Ready.
    ///
    /// # Errors
    /// `NotInitialized` after `shutdown`; `InvalidValue` if a lock is poisoned.
    pub fn barrier(&self) -> CudaResult<()> {
        let mut status = self.status.lock().map_err(|_| CudaError::InvalidValue)?;
        if *status == DistributedStatus::Shutdown {
            return Err(CudaError::NotInitialized);
        }
        // Acquire the epoch lock *before* announcing Synchronizing: the
        // previous code flipped the status first, so a poisoned epoch mutex
        // returned an error with the runtime stuck in Synchronizing forever.
        let mut epoch = self
            .barrier_epoch
            .lock()
            .map_err(|_| CudaError::InvalidValue)?;
        *status = DistributedStatus::Synchronizing;
        *epoch += 1;
        *status = DistributedStatus::Ready;
        Ok(())
    }

    /// Current lifecycle state; reports an `Error` status instead of
    /// panicking if the status lock is poisoned.
    pub fn status(&self) -> DistributedStatus {
        self.status
            .lock()
            .map(|s| s.clone())
            .unwrap_or_else(|_| DistributedStatus::Error("lock poisoned".to_string()))
    }

    /// Moves the runtime to `Shutdown`. Idempotent: repeated calls succeed.
    pub fn shutdown(&self) -> CudaResult<()> {
        let mut status = self.status.lock().map_err(|_| CudaError::InvalidValue)?;
        *status = DistributedStatus::Shutdown;
        Ok(())
    }
}
/// Key/value rendezvous store.
///
/// NOTE(review): despite the name, this implementation keeps all state in
/// process-local maps and opens no sockets; the underscore-prefixed fields
/// look reserved for a future wire protocol — confirm with the crate owners.
pub struct TcpStore {
_master_addr: String,
_port: u16,
_world_size: u32,
/// Whether this process created the store as master.
is_master: bool,
/// String keys mapped to opaque byte payloads (`set`/`get`/`wait`).
data: Arc<Mutex<HashMap<String, Vec<u8>>>>,
/// Named integer counters (`add`).
counters: Arc<Mutex<HashMap<String, i64>>>,
}
impl TcpStore {
    /// Creates an empty store.
    ///
    /// # Errors
    /// `InvalidValue` when `master_addr` is empty or `world_size` is zero.
    pub fn new(master_addr: &str, port: u16, world_size: u32, is_master: bool) -> CudaResult<Self> {
        if master_addr.is_empty() || world_size == 0 {
            return Err(CudaError::InvalidValue);
        }
        let store = Self {
            _master_addr: master_addr.to_owned(),
            _port: port,
            _world_size: world_size,
            is_master,
            data: Arc::new(Mutex::new(HashMap::new())),
            counters: Arc::new(Mutex::new(HashMap::new())),
        };
        Ok(store)
    }

    /// Whether this process created the store as master.
    pub fn is_master(&self) -> bool {
        self.is_master
    }

    /// Stores `value` under `key`, replacing any previous payload.
    pub fn set(&self, key: &str, value: &[u8]) -> CudaResult<()> {
        self.data
            .lock()
            .map_err(|_| CudaError::InvalidValue)?
            .insert(key.to_owned(), value.to_vec());
        Ok(())
    }

    /// Returns a copy of the payload stored under `key`.
    ///
    /// # Errors
    /// `InvalidValue` when the key is absent or the lock is poisoned.
    pub fn get(&self, key: &str) -> CudaResult<Vec<u8>> {
        self.data
            .lock()
            .map_err(|_| CudaError::InvalidValue)?
            .get(key)
            .cloned()
            .ok_or(CudaError::InvalidValue)
    }

    /// Non-blocking check that every key in `keys` is present.
    ///
    /// # Errors
    /// `NotReady` if any key is missing; the caller is expected to retry.
    pub fn wait(&self, keys: &[&str]) -> CudaResult<()> {
        let map = self.data.lock().map_err(|_| CudaError::InvalidValue)?;
        if keys.iter().all(|k| map.contains_key(*k)) {
            Ok(())
        } else {
            Err(CudaError::NotReady)
        }
    }

    /// Adds `amount` to the named counter (created at zero on first use) and
    /// returns the updated value.
    pub fn add(&self, key: &str, amount: i64) -> CudaResult<i64> {
        let mut counters = self.counters.lock().map_err(|_| CudaError::InvalidValue)?;
        let value = counters.entry(key.to_owned()).or_default();
        *value += amount;
        Ok(*value)
    }
}
/// Key/value rendezvous store backed by files under a root directory,
/// intended for processes that share a filesystem (see `InitMethod::FileStore`).
pub struct FileStore {
/// Directory holding one file per key.
root: PathBuf,
}
impl FileStore {
    /// Creates the backing directory (and any parents) if needed and returns
    /// a store rooted there.
    ///
    /// # Errors
    /// `InvalidValue` when the directory cannot be created.
    pub fn new(path: &Path) -> CudaResult<Self> {
        std::fs::create_dir_all(path).map_err(|_| CudaError::InvalidValue)?;
        Ok(Self {
            root: path.to_path_buf(),
        })
    }

    /// Maps `key` to a file path under `root`, replacing every character that
    /// is not alphanumeric, '_' or '-' with '_'.
    /// NOTE(review): distinct keys can collide after sanitization (e.g. "a/b"
    /// and "a_b") — confirm callers pick collision-free keys.
    fn key_path(&self, key: &str) -> PathBuf {
        let sanitized: String = key
            .chars()
            .map(|c| match c {
                c if c.is_alphanumeric() => c,
                '_' | '-' => c,
                _ => '_',
            })
            .collect();
        self.root.join(sanitized)
    }

    /// Writes `value` to the key's file, replacing any existing content.
    pub fn set(&self, key: &str, value: &[u8]) -> CudaResult<()> {
        std::fs::write(self.key_path(key), value).map_err(|_| CudaError::InvalidValue)
    }

    /// Reads back the payload stored under `key`.
    ///
    /// # Errors
    /// `InvalidValue` when the key's file is missing or unreadable.
    pub fn get(&self, key: &str) -> CudaResult<Vec<u8>> {
        std::fs::read(self.key_path(key)).map_err(|_| CudaError::InvalidValue)
    }

    /// Non-blocking check that a file exists for every key in `keys`.
    ///
    /// # Errors
    /// `NotReady` if any key's file is absent; the caller is expected to retry.
    pub fn wait(&self, keys: &[&str]) -> CudaResult<()> {
        match keys.iter().find(|k| !self.key_path(k).exists()) {
            Some(_) => Err(CudaError::NotReady),
            None => Ok(()),
        }
    }

    /// Adds `amount` to the integer stored in the key's file (treated as zero
    /// when the file is absent) and writes back the new value.
    /// NOTE(review): this read-modify-write is not atomic across processes;
    /// concurrent adders can lose updates — confirm callers serialize access.
    pub fn add(&self, key: &str, amount: i64) -> CudaResult<i64> {
        let path = self.key_path(key);
        let current: i64 = if path.exists() {
            std::fs::read_to_string(&path)
                .map_err(|_| CudaError::InvalidValue)?
                .trim()
                .parse()
                .map_err(|_| CudaError::InvalidValue)?
        } else {
            0
        };
        let updated = current + amount;
        std::fs::write(&path, updated.to_string().as_bytes())
            .map_err(|_| CudaError::InvalidValue)?;
        Ok(updated)
    }
}
/// One group of parameters whose gradients are reduced together.
#[derive(Debug, Clone)]
pub struct Bucket {
/// Ids of the parameters assigned to this bucket.
pub param_ids: Vec<usize>,
/// Sum of the assigned gradient sizes, in bytes.
pub total_size: usize,
/// Set via `GradientBucket::mark_ready` once the bucket may be reduced.
pub ready: bool,
}
/// Greedily packs parameter gradients into fixed-capacity buckets in the
/// order they are added.
#[derive(Debug, Clone)]
pub struct GradientBucket {
/// Capacity of each bucket, in bytes.
bucket_size_bytes: usize,
/// Buckets in creation order; only the newest accepts further gradients.
buckets: Vec<Bucket>,
}
impl GradientBucket {
pub fn new(bucket_size_mb: usize) -> Self {
Self {
bucket_size_bytes: bucket_size_mb * 1024 * 1024,
buckets: Vec::new(),
}
}
pub fn add_gradient(&mut self, param_id: usize, grad_size: usize) {
let needs_new = self.buckets.is_empty()
|| self
.buckets
.last()
.is_none_or(|b| b.total_size + grad_size > self.bucket_size_bytes);
if needs_new {
self.buckets.push(Bucket {
param_ids: vec![param_id],
total_size: grad_size,
ready: false,
});
} else if let Some(last) = self.buckets.last_mut() {
last.param_ids.push(param_id);
last.total_size += grad_size;
}
}
pub fn buckets(&self) -> &[Bucket] {
&self.buckets
}
pub fn mark_ready(&mut self, bucket_idx: usize) -> CudaResult<()> {
let bucket = self
.buckets
.get_mut(bucket_idx)
.ok_or(CudaError::InvalidValue)?;
bucket.ready = true;
Ok(())
}
pub fn num_buckets(&self) -> usize {
self.buckets.len()
}
pub fn bucket_capacity(&self) -> usize {
self.bucket_size_bytes
}
}
/// Settings controlling data-parallel gradient synchronization.
#[derive(Debug, Clone)]
pub struct DataParallelConfig {
    /// Capacity of each gradient bucket, in MiB.
    pub gradient_bucket_size_mb: usize,
    /// Whether gradient communication overlaps with the backward pass.
    pub overlap_communication: bool,
    /// Whether to scan for parameters that received no gradient.
    pub find_unused_parameters: bool,
}

impl Default for DataParallelConfig {
    /// 25 MiB buckets, overlapping communication, no unused-parameter scan.
    fn default() -> Self {
        DataParallelConfig {
            gradient_bucket_size_mb: 25,
            overlap_communication: true,
            find_unused_parameters: false,
        }
    }
}
/// Settings for tensor- and pipeline-model parallelism.
#[derive(Debug, Clone)]
pub struct ModelParallelConfig {
    /// Number of ways each layer's tensors are split; must be non-zero.
    pub tensor_parallel_size: u32,
    /// Number of pipeline stages; must be non-zero.
    pub pipeline_parallel_size: u32,
    /// Whether sequence parallelism is enabled.
    pub sequence_parallel: bool,
}

impl ModelParallelConfig {
    /// Rejects zero-sized parallel dimensions.
    ///
    /// # Errors
    /// `InvalidValue` when either parallel size is zero.
    pub fn validate(&self) -> CudaResult<()> {
        match (self.tensor_parallel_size, self.pipeline_parallel_size) {
            (0, _) | (_, 0) => Err(CudaError::InvalidValue),
            _ => Ok(()),
        }
    }

    /// GPUs needed for this layout: tensor size times pipeline size.
    pub fn total_gpus_required(&self) -> u32 {
        self.tensor_parallel_size * self.pipeline_parallel_size
    }
}

impl Default for ModelParallelConfig {
    /// No model parallelism: a single GPU suffices.
    fn default() -> Self {
        Self {
            tensor_parallel_size: 1,
            pipeline_parallel_size: 1,
            sequence_parallel: false,
        }
    }
}
/// Stateless helpers for distributed optimization.
pub struct DistributedOptimizer;

impl DistributedOptimizer {
    /// Simulated all-reduce over gradient buckets; succeeds only when every
    /// bucket has been marked ready.
    ///
    /// # Errors
    /// `NotReady` as soon as a bucket with `ready == false` is found.
    pub fn all_reduce_gradients(buckets: &[Bucket]) -> CudaResult<()> {
        match buckets.iter().position(|bucket| !bucket.ready) {
            Some(_) => Err(CudaError::NotReady),
            None => Ok(()),
        }
    }

    /// Splits `param_count` parameters into `world_size` contiguous shards,
    /// giving the first `param_count % world_size` shards one extra element.
    /// Returns an empty vector when `world_size` is zero.
    pub fn zero_redundancy_partition(world_size: u32, param_count: usize) -> Vec<Range<usize>> {
        let ws = world_size as usize;
        if ws == 0 {
            return Vec::new();
        }
        let (base, remainder) = (param_count / ws, param_count % ws);
        let mut start = 0;
        (0..ws)
            .map(|shard| {
                let len = base + usize::from(shard < remainder);
                let range = start..start + len;
                start = range.end;
                range
            })
            .collect()
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Shared 4-rank, rank-0 configuration used across the runtime tests.
fn sample_config() -> DistributedConfig {
DistributedConfig {
world_size: 4,
local_rank: 0,
global_rank: 0,
master_addr: "127.0.0.1".to_string(),
master_port: 29500,
backend: DistributedBackend::Tcp,
}
}
// --- DistributedConfig::validate ---
#[test]
fn config_valid() {
let cfg = sample_config();
assert!(cfg.validate().is_ok());
}
#[test]
fn config_invalid_world_size_zero() {
let mut cfg = sample_config();
cfg.world_size = 0;
assert!(cfg.validate().is_err());
}
#[test]
fn config_invalid_rank_exceeds_world() {
let mut cfg = sample_config();
cfg.global_rank = 10;
assert!(cfg.validate().is_err());
}
// --- TcpStore key/value, wait, and counter semantics ---
#[test]
fn tcp_store_set_get() {
let store = TcpStore::new("127.0.0.1", 29500, 2, true).expect("create store");
assert!(store.is_master());
store.set("key1", b"hello").expect("set");
let val = store.get("key1").expect("get");
assert_eq!(val, b"hello");
}
#[test]
fn tcp_store_get_missing_key() {
let store = TcpStore::new("127.0.0.1", 29500, 1, true).expect("create store");
assert!(store.get("nonexistent").is_err());
}
#[test]
fn tcp_store_add_counter() {
let store = TcpStore::new("127.0.0.1", 29500, 1, true).expect("create store");
let v1 = store.add("counter", 5).expect("add");
assert_eq!(v1, 5);
let v2 = store.add("counter", 3).expect("add");
assert_eq!(v2, 8);
}
#[test]
fn tcp_store_wait_present() {
let store = TcpStore::new("127.0.0.1", 29500, 1, true).expect("create store");
store.set("a", b"1").expect("set");
store.set("b", b"2").expect("set");
assert!(store.wait(&["a", "b"]).is_ok());
}
#[test]
fn tcp_store_wait_missing() {
let store = TcpStore::new("127.0.0.1", 29500, 1, true).expect("create store");
store.set("a", b"1").expect("set");
assert!(store.wait(&["a", "missing"]).is_err());
}
// --- FileStore: each test uses its own temp dir and cleans up around itself ---
#[test]
fn file_store_set_get() {
let dir = std::env::temp_dir().join("oxicuda_test_filestore_setget");
let _ = std::fs::remove_dir_all(&dir);
let store = FileStore::new(&dir).expect("create file store");
store.set("mykey", b"world").expect("set");
let val = store.get("mykey").expect("get");
assert_eq!(val, b"world");
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn file_store_add_counter() {
let dir = std::env::temp_dir().join("oxicuda_test_filestore_add");
let _ = std::fs::remove_dir_all(&dir);
let store = FileStore::new(&dir).expect("create file store");
let v1 = store.add("ctr", 10).expect("add");
assert_eq!(v1, 10);
// Negative amounts decrement the counter.
let v2 = store.add("ctr", -3).expect("add");
assert_eq!(v2, 7);
let _ = std::fs::remove_dir_all(&dir);
}
#[test]
fn file_store_wait() {
let dir = std::env::temp_dir().join("oxicuda_test_filestore_wait");
let _ = std::fs::remove_dir_all(&dir);
let store = FileStore::new(&dir).expect("create file store");
store.set("x", b"1").expect("set");
assert!(store.wait(&["x"]).is_ok());
assert!(store.wait(&["x", "y"]).is_err());
let _ = std::fs::remove_dir_all(&dir);
}
// --- ProcessGroup membership and local-rank lookup ---
#[test]
fn process_group_creation() {
let pg = ProcessGroup::new(0, vec![0, 1, 2, 3]).expect("create pg");
assert_eq!(pg.size, 4);
assert_eq!(pg.group_id, 0);
assert!(pg.contains_rank(2));
assert!(!pg.contains_rank(5));
assert_eq!(pg.local_rank(3), Some(3));
assert_eq!(pg.local_rank(9), None);
}
#[test]
fn process_group_empty_rejected() {
assert!(ProcessGroup::new(0, vec![]).is_err());
}
// --- GradientBucket packing ---
#[test]
fn gradient_bucket_partitioning() {
// 1 MiB buckets: two 0.5 MiB gradients fill a bucket, the third overflows.
let mut gb = GradientBucket::new(1);
let half_mb = 512 * 1024;
gb.add_gradient(0, half_mb);
gb.add_gradient(1, half_mb);
gb.add_gradient(2, half_mb);
assert_eq!(gb.num_buckets(), 2);
assert_eq!(gb.buckets()[0].param_ids, vec![0, 1]);
assert_eq!(gb.buckets()[1].param_ids, vec![2]);
}
#[test]
fn gradient_bucket_size_distribution() {
// 2 MiB buckets holding five 1 MiB gradients -> 2 + 2 + 1.
let mut gb = GradientBucket::new(2); let one_mb = 1024 * 1024;
for i in 0..5 {
gb.add_gradient(i, one_mb);
}
assert_eq!(gb.num_buckets(), 3);
assert_eq!(gb.buckets()[0].total_size, 2 * one_mb);
assert_eq!(gb.buckets()[2].param_ids.len(), 1);
}
// --- ZeRO-style parameter sharding ---
#[test]
fn zero_sharding_even() {
let ranges = DistributedOptimizer::zero_redundancy_partition(4, 100);
assert_eq!(ranges.len(), 4);
assert_eq!(ranges[0], 0..25);
assert_eq!(ranges[1], 25..50);
assert_eq!(ranges[2], 50..75);
assert_eq!(ranges[3], 75..100);
}
#[test]
fn zero_sharding_uneven() {
// Remainder of 1 goes to the first shard: 4 + 3 + 3.
let ranges = DistributedOptimizer::zero_redundancy_partition(3, 10);
assert_eq!(ranges.len(), 3);
assert_eq!(ranges[0], 0..4);
assert_eq!(ranges[1], 4..7);
assert_eq!(ranges[2], 7..10);
}
#[test]
fn zero_sharding_zero_world() {
let ranges = DistributedOptimizer::zero_redundancy_partition(0, 100);
assert!(ranges.is_empty());
}
// --- DistributedRuntime lifecycle ---
#[test]
fn from_env_missing_vars() {
// Only checks that from_env does not panic; the result depends on the
// ambient environment, so no assertion is made.
let _result = DistributedRuntime::from_env();
}
#[test]
fn barrier_increments_epoch() {
let rt = DistributedRuntime::init(sample_config()).expect("init");
rt.barrier().expect("barrier 1");
rt.barrier().expect("barrier 2");
let epoch = rt.barrier_epoch.lock().expect("lock");
assert_eq!(*epoch, 2);
}
#[test]
fn master_detection() {
let cfg = sample_config(); let rt = DistributedRuntime::init(cfg).expect("init");
assert!(rt.is_master());
let mut cfg2 = sample_config();
cfg2.global_rank = 2;
let rt2 = DistributedRuntime::init(cfg2).expect("init");
assert!(!rt2.is_master());
}
#[test]
fn world_size_rank_accessors() {
let mut cfg = sample_config();
cfg.world_size = 8;
cfg.global_rank = 3;
cfg.local_rank = 1;
let rt = DistributedRuntime::init(cfg).expect("init");
assert_eq!(rt.world_size(), 8);
assert_eq!(rt.rank(), 3);
assert_eq!(rt.local_rank(), 1);
}
// --- Parallelism configuration defaults/validation ---
#[test]
fn data_parallel_config_defaults() {
let dpc = DataParallelConfig::default();
assert_eq!(dpc.gradient_bucket_size_mb, 25);
assert!(dpc.overlap_communication);
assert!(!dpc.find_unused_parameters);
}
#[test]
fn model_parallel_config_validation() {
let mpc = ModelParallelConfig::default();
assert!(mpc.validate().is_ok());
assert_eq!(mpc.total_gpus_required(), 1);
let bad = ModelParallelConfig {
tensor_parallel_size: 0,
pipeline_parallel_size: 4,
sequence_parallel: false,
};
assert!(bad.validate().is_err());
}
#[test]
fn status_transitions() {
let rt = DistributedRuntime::init(sample_config()).expect("init");
assert_eq!(rt.status(), DistributedStatus::Ready);
rt.barrier().expect("barrier");
// Barrier restores Ready after transiently entering Synchronizing.
assert_eq!(rt.status(), DistributedStatus::Ready);
rt.shutdown().expect("shutdown");
assert_eq!(rt.status(), DistributedStatus::Shutdown);
}
#[test]
fn shutdown_idempotent() {
let rt = DistributedRuntime::init(sample_config()).expect("init");
rt.shutdown().expect("shutdown 1");
rt.shutdown().expect("shutdown 2");
assert_eq!(rt.status(), DistributedStatus::Shutdown);
}
#[test]
fn barrier_after_shutdown_fails() {
let rt = DistributedRuntime::init(sample_config()).expect("init");
rt.shutdown().expect("shutdown");
assert!(rt.barrier().is_err());
}
// --- DistributedOptimizer readiness checks ---
#[test]
fn all_reduce_gradients_ready() {
let buckets = vec![
Bucket {
param_ids: vec![0, 1],
total_size: 1024,
ready: true,
},
Bucket {
param_ids: vec![2],
total_size: 512,
ready: true,
},
];
assert!(DistributedOptimizer::all_reduce_gradients(&buckets).is_ok());
}
#[test]
fn all_reduce_gradients_not_ready() {
let buckets = vec![
Bucket {
param_ids: vec![0],
total_size: 1024,
ready: true,
},
Bucket {
param_ids: vec![1],
total_size: 512,
ready: false,
},
];
assert!(DistributedOptimizer::all_reduce_gradients(&buckets).is_err());
}
// --- NodeInfo / NodeId basics ---
#[test]
fn node_info_creation() {
let ni = NodeInfo::new(NodeId::new(0), "host0", "10.0.0.1", 8080, 4, 0);
assert_eq!(ni.node_id, NodeId(0));
assert_eq!(ni.hostname, "host0");
assert_eq!(ni.gpu_count, 4);
}
#[test]
fn node_id_display() {
let id = NodeId::new(42);
assert_eq!(format!("{id}"), "Node(42)");
assert_eq!(id.value(), 42);
}
}