use std::time::Duration;
use ferrotorch_core::FerrotorchResult;
use crate::backend::Backend;
#[cfg(not(feature = "mpi-native"))]
use crate::error::DistributedError;
#[cfg(feature = "mpi-native")]
mod native {
    //! Environment-driven rendezvous configuration for the `mpi-native` build.

    use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4};

    use crate::error::DistributedError;
    use crate::gloo_native::GlooRendezvousConfig;

    /// `(rank var, size var)` pairs, in priority order: `OMPI_*`, then `PMI_*`,
    /// then the PyTorch-style `RANK` / `WORLD_SIZE`.
    const RANK_SIZE_VARS: [(&str, &str); 3] = [
        ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"),
        ("PMI_RANK", "PMI_SIZE"),
        ("RANK", "WORLD_SIZE"),
    ];

    /// Reads `key` and parses it as a `usize`; `None` when unset or unparseable.
    fn env_usize(key: &str) -> Option<usize> {
        std::env::var(key).ok()?.parse().ok()
    }

    /// Reads a mandatory env var, mapping its absence to `DistributedError::Io`.
    fn required_env(key: &str) -> Result<String, DistributedError> {
        std::env::var(key).map_err(|_| DistributedError::Io {
            message: format!("mpi_backend rendezvous: env var `{key}` is not set"),
        })
    }

    /// Builds a gloo rendezvous config from launcher environment variables.
    ///
    /// The rank and world size come from the first pair in [`RANK_SIZE_VARS`]
    /// whose two variables are both set to valid unsigned integers. A pair with
    /// a missing or unparseable member is skipped. The master endpoint comes
    /// from `MASTER_ADDR` + `MASTER_PORT`. Bare IPv6 literals are bracketed so
    /// the result matches `SocketAddr`'s `[host]:port` form.
    ///
    /// # Errors
    ///
    /// Returns `DistributedError::Io` in these cases:
    /// - no rank/size pair is usable;
    /// - `rank >= world_size`;
    /// - `MASTER_ADDR` / `MASTER_PORT` is unset;
    /// - `MASTER_PORT` is not a valid `u16`.
    pub(super) fn mpi_rendezvous_from_env() -> Result<GlooRendezvousConfig, DistributedError> {
        let (rank, world_size) = RANK_SIZE_VARS
            .iter()
            .find_map(|&(rank_key, size_key)| Some((env_usize(rank_key)?, env_usize(size_key)?)))
            .ok_or_else(|| DistributedError::Io {
                message: "mpi_backend rendezvous: none of (OMPI_COMM_WORLD_RANK + \
                          OMPI_COMM_WORLD_SIZE), (PMI_RANK + PMI_SIZE), (RANK + WORLD_SIZE) \
                          are set (as unsigned integers) in the environment"
                    .to_string(),
            })?;

        // `rank >= world_size` also rejects `world_size == 0`.
        if rank >= world_size {
            return Err(DistributedError::Io {
                message: format!(
                    "mpi_backend rendezvous: rank {rank} is out of range for world size \
                     {world_size}"
                ),
            });
        }

        let master_host = required_env("MASTER_ADDR")?;
        let master_port_raw = required_env("MASTER_PORT")?;
        let master_port: u16 = master_port_raw.parse().map_err(|_| DistributedError::Io {
            message: format!(
                "mpi_backend rendezvous: env var `MASTER_PORT` = {master_port_raw:?} is not a \
                 valid TCP port"
            ),
        })?;

        // A bare IPv6 literal joined with ":port" is ambiguous (`::1:29500`);
        // bracket it like `SocketAddr`'s Display does (`[::1]:29500`).
        let master_addr = if master_host.parse::<Ipv6Addr>().is_ok() {
            format!("[{master_host}]:{master_port}")
        } else {
            format!("{master_host}:{master_port}")
        };

        Ok(GlooRendezvousConfig {
            master_addr,
            rank,
            world_size,
            // NOTE(review): loopback bind with an OS-assigned port, same as
            // `MpiBackend::new`; presumably single-host only — confirm for
            // multi-node launches.
            bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)),
        })
    }
}
/// Returns `true` when this crate was compiled with the `mpi-native` feature.
///
/// Without that feature, [`MpiBackend::new`] and [`MpiBackend::from_env`]
/// always return an error, so callers can use this to pick another backend
/// up front.
pub fn is_mpi_available() -> bool {
    const COMPILED_WITH_MPI_NATIVE: bool = cfg!(feature = "mpi-native");
    COMPILED_WITH_MPI_NATIVE
}
/// MPI-flavoured distributed backend.
///
/// With `mpi-native` it wraps a gloo-native backend and takes its rendezvous
/// parameters MPI-launcher style. Without that feature it is a zero-sized
/// placeholder that no public constructor can produce.
#[derive(Debug)]
pub struct MpiBackend {
    /// Placeholder so the type still exists in builds without `mpi-native`.
    #[cfg(not(feature = "mpi-native"))]
    _phantom: std::marker::PhantomData<()>,
    /// The gloo-native transport that performs all communication.
    #[cfg(feature = "mpi-native")]
    inner: crate::gloo_native::GlooBackendInner,
}
impl MpiBackend {
    /// Creates a backend for `rank` out of `world_size`, rendezvousing at
    /// `master_addr`.
    ///
    /// With `mpi-native`, this builds a gloo-native rendezvous config and
    /// hands it to `GlooBackendInner::new`. The config binds on `127.0.0.1`
    /// with an OS-assigned port.
    ///
    /// # Errors
    ///
    /// Without `mpi-native`, this always returns
    /// `DistributedError::BackendUnavailable { backend: "mpi" }`, converted via
    /// `Into`. With the feature, it propagates any error from
    /// `GlooBackendInner::new`.
    pub fn new(rank: usize, world_size: usize, master_addr: &str) -> FerrotorchResult<Self> {
        #[cfg(feature = "mpi-native")]
        {
            use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

            let cfg = crate::gloo_native::GlooRendezvousConfig {
                master_addr: master_addr.to_string(),
                rank,
                world_size,
                // NOTE(review): loopback bind — presumably only suitable for
                // single-host runs; confirm before multi-node use.
                bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)),
            };
            let inner = crate::gloo_native::GlooBackendInner::new(&cfg)?;
            Ok(Self { inner })
        }
        #[cfg(not(feature = "mpi-native"))]
        {
            // The arguments only matter to the native build. Consume them here
            // rather than blanket-allowing `unused_variables`, which would also
            // hide genuinely unused arguments in the native build.
            let _ = (rank, world_size, master_addr);
            Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
        }
    }

    /// Creates a backend whose rank, world size and master endpoint come from
    /// launcher environment variables.
    ///
    /// Rank and world size are read from `OMPI_COMM_WORLD_*`, then `PMI_*`,
    /// then `RANK`/`WORLD_SIZE`. The endpoint comes from
    /// `MASTER_ADDR`/`MASTER_PORT`.
    ///
    /// # Errors
    ///
    /// Without `mpi-native`, this always returns `BackendUnavailable`. With the
    /// feature, it propagates rendezvous-config errors (missing variables) and
    /// `GlooBackendInner::new` errors.
    pub fn from_env() -> FerrotorchResult<Self> {
        #[cfg(feature = "mpi-native")]
        {
            let cfg = native::mpi_rendezvous_from_env()?;
            let inner = crate::gloo_native::GlooBackendInner::new(&cfg)?;
            Ok(Self { inner })
        }
        #[cfg(not(feature = "mpi-native"))]
        {
            Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
        }
    }

    /// In-place element-wise sum of `data` across all ranks.
    ///
    /// Delegates to `GlooBackendInner::ring_allreduce_sum_f32`. Every rank is
    /// expected to call this with a buffer of the same length.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner backend.
    #[cfg(feature = "mpi-native")]
    pub fn allreduce_sum_f32(&self, data: &mut [f32]) -> FerrotorchResult<()> {
        self.inner.ring_allreduce_sum_f32(data)
    }

    /// Overwrites `data` on every rank with the contents of `root`'s `data`.
    ///
    /// Delegates to `GlooBackendInner::tree_broadcast_f32`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner backend.
    #[cfg(feature = "mpi-native")]
    pub fn broadcast_f32(&self, data: &mut [f32], root: usize) -> FerrotorchResult<()> {
        self.inner.tree_broadcast_f32(data, root)
    }
}
impl Backend for MpiBackend {
    /// Rank of this process within the group.
    ///
    /// Delegates to the inner gloo-native backend. Builds without `mpi-native`
    /// report `0`, but no constructor in this file can produce an
    /// `MpiBackend` in such a build.
    fn rank(&self) -> usize {
        #[cfg(not(feature = "mpi-native"))]
        {
            0
        }
        #[cfg(feature = "mpi-native")]
        {
            self.inner.rank()
        }
    }

    /// Number of processes in the group (`0` in builds without `mpi-native`).
    fn world_size(&self) -> usize {
        #[cfg(not(feature = "mpi-native"))]
        {
            0
        }
        #[cfg(feature = "mpi-native")]
        {
            // Fully-qualified call kept from the original implementation.
            Backend::world_size(&self.inner)
        }
    }

    /// Sends `data` to `dst_rank`.
    ///
    /// # Errors
    ///
    /// Returns `BackendUnavailable` without `mpi-native`. Otherwise returns
    /// whatever the inner backend reports.
    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
        #[cfg(not(feature = "mpi-native"))]
        {
            let _ = (data, dst_rank);
            Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
        }
        #[cfg(feature = "mpi-native")]
        {
            self.inner.send(data, dst_rank)
        }
    }

    /// Receives from `src_rank` into `dst`.
    ///
    /// # Errors
    ///
    /// Returns `BackendUnavailable` without `mpi-native`. Otherwise returns
    /// whatever the inner backend reports.
    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
        #[cfg(not(feature = "mpi-native"))]
        {
            let _ = (dst, src_rank);
            Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
        }
        #[cfg(feature = "mpi-native")]
        {
            self.inner.recv(dst, src_rank)
        }
    }

    /// Receives from `src_rank` into `dst`, passing `timeout` through to the
    /// inner backend.
    ///
    /// # Errors
    ///
    /// Returns `BackendUnavailable` without `mpi-native`. Otherwise returns
    /// whatever the inner backend reports.
    fn recv_timeout(
        &self,
        dst: &mut [u8],
        src_rank: usize,
        timeout: Duration,
    ) -> FerrotorchResult<()> {
        #[cfg(not(feature = "mpi-native"))]
        {
            let _ = (dst, src_rank, timeout);
            Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
        }
        #[cfg(feature = "mpi-native")]
        {
            self.inner.recv_timeout(dst, src_rank, timeout)
        }
    }

    /// Blocks until every rank reaches the barrier (delegated to the inner
    /// backend).
    ///
    /// # Errors
    ///
    /// Returns `BackendUnavailable` without `mpi-native`. Otherwise returns
    /// whatever the inner backend reports.
    fn barrier(&self) -> FerrotorchResult<()> {
        #[cfg(not(feature = "mpi-native"))]
        {
            Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
        }
        #[cfg(feature = "mpi-native")]
        {
            self.inner.barrier()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "mpi-native"))]
    use ferrotorch_core::FerrotorchError;

    #[cfg(not(feature = "mpi-native"))]
    #[test]
    fn mpi_unavailable_without_feature() {
        let err = MpiBackend::new(0, 2, "127.0.0.1:0").expect_err("default build must err");
        match err {
            FerrotorchError::InvalidArgument { ref message } => {
                assert!(
                    message.contains("`mpi`"),
                    "expected message to discriminate the mpi backend by name, got: {message}"
                );
                assert!(
                    !message.contains("`gloo`") && !message.contains("`ucc`"),
                    "message must not name a different backend, got: {message}"
                );
            }
            other => panic!(
                "expected FerrotorchError::InvalidArgument from BackendUnavailable, got {other:?}"
            ),
        }
    }

    #[cfg(not(feature = "mpi-native"))]
    #[test]
    fn mpi_from_env_unavailable_without_feature() {
        let err = MpiBackend::from_env().expect_err("default build must err");
        match err {
            FerrotorchError::InvalidArgument { message } => {
                assert!(message.contains("`mpi`"));
            }
            other => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    #[test]
    fn is_mpi_available_default_off() {
        // Check both configurations, not just the default one.
        assert_eq!(is_mpi_available(), cfg!(feature = "mpi-native"));
    }

    #[cfg(feature = "mpi-native")]
    #[test]
    fn mpi_native_e2e_allreduce_two_ranks() {
        use std::net::TcpListener;
        use std::sync::Arc;
        use std::thread;
        let probe = TcpListener::bind("127.0.0.1:0").expect("probe bind");
        let master_addr = probe.local_addr().expect("local_addr").to_string();
        drop(probe);
        let world_size = 2usize;
        let handles: Vec<_> = (0..world_size)
            .map(|rank| {
                let ma = master_addr.clone();
                thread::spawn(move || {
                    Arc::new(MpiBackend::new(rank, world_size, &ma).expect("MpiBackend::new"))
                })
            })
            .collect();
        let backends: Vec<_> = handles
            .into_iter()
            .map(|h| h.join().expect("join"))
            .collect();
        thread::scope(|s| {
            let b0 = Arc::clone(&backends[0]);
            let b1 = Arc::clone(&backends[1]);
            let h0 = s.spawn(move || {
                let mut a = vec![1.0f32, 2.0, 3.0, 4.0];
                b0.allreduce_sum_f32(&mut a).expect("allreduce rank 0");
                a
            });
            let h1 = s.spawn(move || {
                let mut a = vec![10.0f32, 20.0, 30.0, 40.0];
                b1.allreduce_sum_f32(&mut a).expect("allreduce rank 1");
                a
            });
            let r0 = h0.join().unwrap();
            let r1 = h1.join().unwrap();
            let expected = vec![11.0f32, 22.0, 33.0, 44.0];
            assert_eq!(r0, expected, "rank 0 allreduce result");
            assert_eq!(r1, expected, "rank 1 allreduce result");
        });
    }

    #[cfg(feature = "mpi-native")]
    #[test]
    fn mpi_native_e2e_broadcast_and_barrier_three_ranks() {
        use std::net::TcpListener;
        use std::sync::Arc;
        use std::thread;
        let probe = TcpListener::bind("127.0.0.1:0").expect("probe bind");
        let master_addr = probe.local_addr().expect("local_addr").to_string();
        drop(probe);
        let world_size = 3usize;
        let handles: Vec<_> = (0..world_size)
            .map(|rank| {
                let ma = master_addr.clone();
                thread::spawn(move || {
                    Arc::new(MpiBackend::new(rank, world_size, &ma).expect("MpiBackend::new"))
                })
            })
            .collect();
        let backends: Vec<_> = handles
            .into_iter()
            .map(|h| h.join().expect("join"))
            .collect();
        let payload = vec![7.5f32, 8.25, 9.125];
        let root = 1usize;
        thread::scope(|s| {
            let mut handles = Vec::new();
            for (rank, backend) in backends.iter().enumerate() {
                let b = Arc::clone(backend);
                let p = payload.clone();
                handles.push(s.spawn(move || {
                    let mut data = if rank == root { p } else { vec![0.0f32; 3] };
                    b.broadcast_f32(&mut data, root).expect("broadcast");
                    Backend::barrier(&*b).expect("barrier");
                    data
                }));
            }
            for h in handles {
                let got = h.join().unwrap();
                assert_eq!(got, vec![7.5f32, 8.25, 9.125]);
            }
        });
    }

    /// Sets an environment variable for the duration of a test.
    #[cfg(feature = "mpi-native")]
    fn set_env(key: &str, value: impl AsRef<std::ffi::OsStr>) {
        // SAFETY: `set_var` is unsafe because another thread reading the
        // environment at the same time is undefined behaviour. No other test in
        // this module reads these variables. NOTE(review): the harness runs tests
        // on parallel threads; if other tests in the crate read the environment,
        // serialize env access behind a shared lock or use `--test-threads=1`.
        unsafe { std::env::set_var(key, value) }
    }

    /// Removes an environment variable for the duration of a test.
    #[cfg(feature = "mpi-native")]
    fn remove_env(key: &str) {
        // SAFETY: same invariant as `set_env`.
        unsafe { std::env::remove_var(key) }
    }

    /// Snapshot of selected environment variables, restored on drop.
    ///
    /// Because restoration happens in `Drop`, a failing assertion cannot leak
    /// test-only values into later tests, nor clobber values the test runner
    /// already had set.
    #[cfg(feature = "mpi-native")]
    struct EnvGuard {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
    }

    #[cfg(feature = "mpi-native")]
    impl EnvGuard {
        /// Records the current value (or absence) of each key.
        fn capture(keys: &[&'static str]) -> Self {
            let saved = keys.iter().map(|&k| (k, std::env::var_os(k))).collect();
            Self { saved }
        }
    }

    #[cfg(feature = "mpi-native")]
    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in self.saved.drain(..) {
                match value {
                    Some(v) => set_env(key, v),
                    None => remove_env(key),
                }
            }
        }
    }

    #[cfg(feature = "mpi-native")]
    #[test]
    fn mpi_native_from_env_prefers_ompi_then_pmi_then_pytorch() {
        const KEYS: [&str; 8] = [
            "OMPI_COMM_WORLD_RANK",
            "OMPI_COMM_WORLD_SIZE",
            "PMI_RANK",
            "PMI_SIZE",
            "RANK",
            "WORLD_SIZE",
            "MASTER_ADDR",
            "MASTER_PORT",
        ];
        // Restores the pre-test environment on exit, including on panic.
        let _restore = EnvGuard::capture(&KEYS);
        for key in KEYS {
            remove_env(key);
        }

        set_env("MASTER_ADDR", "127.0.0.1");
        set_env("MASTER_PORT", "29555");
        set_env("RANK", "3");
        set_env("WORLD_SIZE", "4");
        let cfg = native::mpi_rendezvous_from_env().expect("pytorch fallback");
        assert_eq!(cfg.rank, 3, "pytorch fallback rank");
        assert_eq!(cfg.world_size, 4, "pytorch fallback world_size");
        assert_eq!(cfg.master_addr, "127.0.0.1:29555");

        set_env("PMI_RANK", "5");
        set_env("PMI_SIZE", "8");
        let cfg = native::mpi_rendezvous_from_env().expect("pmi over pytorch");
        assert_eq!(cfg.rank, 5, "pmi rank wins over RANK");
        assert_eq!(cfg.world_size, 8, "pmi size wins over WORLD_SIZE");

        set_env("OMPI_COMM_WORLD_RANK", "7");
        set_env("OMPI_COMM_WORLD_SIZE", "16");
        let cfg = native::mpi_rendezvous_from_env().expect("ompi over pmi");
        assert_eq!(cfg.rank, 7, "ompi rank wins over pmi");
        assert_eq!(cfg.world_size, 16, "ompi size wins over pmi");
    }
}