use crate::{Backend, ProcessGroup, TorshDistributedError, TorshResult};
/// Runs `f` with shared access to the backend owned by `process_group`.
///
/// A read guard is obtained from the backend handle (via `read()`) and kept
/// alive for the whole call, so `f` observes the backend through that guard.
///
/// # Errors
///
/// Returns the error from [`validate_backend_initialized`] if the backend is
/// not ready; `f` is not called in that case. Otherwise returns whatever `f`
/// returns.
pub fn with_backend_read<T, F>(process_group: &ProcessGroup, f: F) -> TorshResult<T>
where
    F: FnOnce(&dyn Backend) -> TorshResult<T>,
{
    let handle = process_group.backend();
    let guard = handle.read();
    // Only hand the backend to the caller once it reports itself ready.
    validate_backend_initialized(&**guard).and_then(|()| f(&**guard))
}
/// Runs `f` with exclusive (mutable) access to the backend owned by
/// `process_group`.
///
/// A write guard is obtained from the backend handle (via `write()`) and
/// kept alive for the whole call, so `f` may mutate the backend through it.
///
/// # Errors
///
/// Returns the error from [`validate_backend_initialized`] if the backend is
/// not ready; `f` is not called in that case. Otherwise returns whatever `f`
/// returns.
pub fn with_backend_write<T, F>(process_group: &ProcessGroup, f: F) -> TorshResult<T>
where
    F: FnOnce(&mut dyn Backend) -> TorshResult<T>,
{
    let handle = process_group.backend();
    let mut guard = handle.write();
    // The readiness check borrows the guard immutably and finishes before the
    // closure takes the mutable borrow it needs to call `f`.
    validate_backend_initialized(&**guard).and_then(|()| f(&mut **guard))
}
/// Checks that `rank` is a valid index into a group of `world_size` ranks.
///
/// Valid ranks are `0..world_size` (the upper bound is exclusive).
///
/// # Errors
///
/// Returns [`TorshDistributedError::RankOutOfBounds`] carrying both values
/// when `rank >= world_size`.
pub fn validate_rank(rank: u32, world_size: u32) -> TorshResult<()> {
    if rank < world_size {
        Ok(())
    } else {
        Err(TorshDistributedError::RankOutOfBounds { rank, world_size })
    }
}
/// Checks that `backend` reports itself ready via `Backend::is_ready`.
///
/// # Errors
///
/// Returns [`TorshDistributedError::BackendNotInitialized`] when
/// `is_ready()` returns `false`.
pub fn validate_backend_initialized(backend: &dyn Backend) -> TorshResult<()> {
    if backend.is_ready() {
        Ok(())
    } else {
        Err(TorshDistributedError::BackendNotInitialized)
    }
}
/// Checks every entry of `ranks` with [`validate_rank`] against `world_size`.
///
/// An empty slice is trivially valid.
///
/// # Errors
///
/// Returns the error for the first out-of-bounds rank in slice order; the
/// remaining ranks are not examined.
pub fn validate_ranks(ranks: &[u32], world_size: u32) -> TorshResult<()> {
    ranks
        .iter()
        .try_for_each(|&rank| validate_rank(rank, world_size))
}
/// Returns `true` when this process has rank 0 in `process_group`.
pub fn is_root_process(process_group: &ProcessGroup) -> bool {
    const ROOT_RANK: u32 = 0;
    process_group.rank() == ROOT_RANK
}
pub fn get_world_size(process_group: &ProcessGroup) -> u32 {
process_group.world_size()
}
pub fn get_rank(process_group: &ProcessGroup) -> u32 {
process_group.rank()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BackendType;

    #[test]
    fn test_validate_rank() {
        // Ranks strictly below the world size are accepted...
        for rank in [0, 3] {
            assert!(validate_rank(rank, 4).is_ok());
        }
        // ...and ranks at or above it are rejected.
        for rank in [4, 10] {
            assert!(validate_rank(rank, 4).is_err());
        }
    }

    #[test]
    fn test_validate_ranks() {
        let world_size = 4;
        assert!(validate_ranks(&[0, 1, 2], world_size).is_ok());
        // A single out-of-bounds entry anywhere in the slice fails the whole check.
        for bad in [&[0, 4, 2][..], &[0, 1, 10][..]] {
            assert!(validate_ranks(bad, world_size).is_err());
        }
    }

    #[tokio::test]
    async fn test_is_root_process() {
        let group = ProcessGroup::new(BackendType::Gloo, 0, 4, "localhost", 8080)
            .await
            .expect("operation should succeed");
        assert!(is_root_process(&group));
    }
}