use std::time::Duration;
use ferrotorch_core::FerrotorchResult;
use crate::backend::Backend;
#[cfg(not(feature = "gloo-backend"))]
use crate::error::DistributedError;
/// Local alias module for the native gloo implementation types.
///
/// Only compiled when the `gloo-backend` feature is enabled, so every use of
/// `native::…` in this file must sit behind the same feature gate.
#[cfg(feature = "gloo-backend")]
mod native {
    pub use crate::gloo_native::GlooBackendInner;
    pub use crate::gloo_native::GlooRendezvousConfig;
}
/// Reports whether this build was compiled with the `gloo-backend` feature.
///
/// The answer is fixed at compile time; nothing is probed at runtime.
pub fn is_gloo_available() -> bool {
    const COMPILED_IN: bool = cfg!(feature = "gloo-backend");
    COMPILED_IN
}
/// Wrapper around the native gloo backend implementation.
///
/// With the `gloo-backend` feature enabled the struct holds the native
/// `GlooBackendInner`. Without it the struct only carries a zero-sized
/// placeholder, and both constructors in this file return
/// `DistributedError::BackendUnavailable`.
#[derive(Debug)]
pub struct GlooBackend {
    /// Zero-sized stand-in used when gloo support is not compiled in; it keeps
    /// the struct's fields private in that configuration as well.
    #[cfg(not(feature = "gloo-backend"))]
    _phantom: std::marker::PhantomData<()>,
    /// Native implementation that every method delegates to.
    #[cfg(feature = "gloo-backend")]
    inner: native::GlooBackendInner,
}
impl GlooBackend {
    /// Creates a backend for process `rank` out of `world_size`, passing
    /// `master_addr` through as the rendezvous configuration's `master_addr`.
    ///
    /// With the `gloo-backend` feature the configuration's `bind_addr` is the
    /// IPv4 loopback address with port 0.
    ///
    /// # Errors
    ///
    /// Without the `gloo-backend` feature this always returns
    /// `DistributedError::BackendUnavailable { backend: "gloo" }` converted
    /// into the crate-wide error type. With the feature, any error from the
    /// native backend constructor is propagated.
    // The arguments are only read inside the feature-gated branch, so they are
    // unused in builds without `gloo-backend`.
    #[allow(unused_variables)]
    pub fn new(rank: usize, world_size: usize, master_addr: &str) -> FerrotorchResult<Self> {
        #[cfg(feature = "gloo-backend")]
        {
            use std::net::{Ipv4Addr, SocketAddr};

            let rendezvous = native::GlooRendezvousConfig {
                master_addr: master_addr.to_owned(),
                rank,
                world_size,
                bind_addr: SocketAddr::from((Ipv4Addr::LOCALHOST, 0)),
            };
            Self::from_rendezvous(&rendezvous)
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            Err(DistributedError::BackendUnavailable { backend: "gloo" }.into())
        }
    }

    /// Creates a backend from the configuration returned by
    /// `GlooRendezvousConfig::from_env`.
    ///
    /// # Errors
    ///
    /// Without the `gloo-backend` feature this always returns
    /// `DistributedError::BackendUnavailable { backend: "gloo" }`. With the
    /// feature, errors from reading the configuration or from the native
    /// backend constructor are propagated.
    pub fn from_env() -> FerrotorchResult<Self> {
        #[cfg(feature = "gloo-backend")]
        {
            let rendezvous = native::GlooRendezvousConfig::from_env()?;
            Self::from_rendezvous(&rendezvous)
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            Err(DistributedError::BackendUnavailable { backend: "gloo" }.into())
        }
    }

    /// Builds the native backend from an already assembled rendezvous
    /// configuration and wraps it.
    #[cfg(feature = "gloo-backend")]
    fn from_rendezvous(config: &native::GlooRendezvousConfig) -> FerrotorchResult<Self> {
        let inner = native::GlooBackendInner::new(config)?;
        Ok(Self { inner })
    }

    /// Delegates to the native backend's `ring_allreduce_sum_f32`, handing it
    /// `data` mutably.
    ///
    /// Only available with the `gloo-backend` feature.
    #[cfg(feature = "gloo-backend")]
    pub fn ring_allreduce_sum_f32(&self, data: &mut [f32]) -> FerrotorchResult<()> {
        self.inner.ring_allreduce_sum_f32(data)
    }

    /// Delegates to the native backend's `tree_broadcast_f32`, handing it
    /// `data` mutably together with the `root` rank.
    ///
    /// Only available with the `gloo-backend` feature.
    #[cfg(feature = "gloo-backend")]
    pub fn tree_broadcast_f32(&self, data: &mut [f32], root: usize) -> FerrotorchResult<()> {
        self.inner.tree_broadcast_f32(data, root)
    }
}
/// Result returned by every fallible `Backend` method of `GlooBackend` when
/// gloo support is not compiled in.
#[cfg(not(feature = "gloo-backend"))]
fn gloo_backend_unavailable() -> FerrotorchResult<()> {
    Err(DistributedError::BackendUnavailable { backend: "gloo" }.into())
}

impl Backend for GlooBackend {
    /// Returns the native backend's rank, or `0` when the `gloo-backend`
    /// feature is disabled.
    fn rank(&self) -> usize {
        #[cfg(feature = "gloo-backend")]
        {
            self.inner.rank()
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            0
        }
    }

    /// Returns the native backend's world size, or `0` when the
    /// `gloo-backend` feature is disabled.
    fn world_size(&self) -> usize {
        #[cfg(feature = "gloo-backend")]
        {
            // Fully qualified so the `Backend` trait method is the one invoked.
            Backend::world_size(&self.inner)
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            0
        }
    }

    /// Forwards `data` and `dst_rank` to the native backend's `send`.
    ///
    /// Returns `BackendUnavailable` when the feature is disabled.
    // Arguments are only read in the feature-gated branch.
    #[allow(unused_variables)]
    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
        #[cfg(feature = "gloo-backend")]
        {
            self.inner.send(data, dst_rank)
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            gloo_backend_unavailable()
        }
    }

    /// Forwards `dst` and `src_rank` to the native backend's `recv`.
    ///
    /// Returns `BackendUnavailable` when the feature is disabled.
    // Arguments are only read in the feature-gated branch.
    #[allow(unused_variables)]
    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
        #[cfg(feature = "gloo-backend")]
        {
            self.inner.recv(dst, src_rank)
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            gloo_backend_unavailable()
        }
    }

    /// Forwards `dst`, `src_rank` and `timeout` to the native backend's
    /// `recv_timeout`.
    ///
    /// Returns `BackendUnavailable` when the feature is disabled.
    // Arguments are only read in the feature-gated branch.
    #[allow(unused_variables)]
    fn recv_timeout(
        &self,
        dst: &mut [u8],
        src_rank: usize,
        timeout: Duration,
    ) -> FerrotorchResult<()> {
        #[cfg(feature = "gloo-backend")]
        {
            self.inner.recv_timeout(dst, src_rank, timeout)
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            gloo_backend_unavailable()
        }
    }

    /// Delegates to the native backend's `barrier`.
    ///
    /// Returns `BackendUnavailable` when the feature is disabled.
    fn barrier(&self) -> FerrotorchResult<()> {
        #[cfg(feature = "gloo-backend")]
        {
            self.inner.barrier()
        }
        #[cfg(not(feature = "gloo-backend"))]
        {
            gloo_backend_unavailable()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(not(feature = "gloo-backend"))]
    use ferrotorch_core::FerrotorchError;

    /// Without the feature, `new` must fail with an `InvalidArgument` whose
    /// message names gloo and no other backend.
    #[cfg(not(feature = "gloo-backend"))]
    #[test]
    fn gloo_unavailable_without_feature() {
        let err = GlooBackend::new(0, 2, "127.0.0.1:0").expect_err("default build must err");
        let message = match err {
            FerrotorchError::InvalidArgument { message } => message,
            other => panic!(
                "expected FerrotorchError::InvalidArgument from BackendUnavailable, got {other:?}"
            ),
        };
        assert!(
            message.contains("`gloo`"),
            "expected message to discriminate the gloo backend by name, got: {message}"
        );
        for other_backend in ["`mpi`", "`ucc`"] {
            assert!(
                !message.contains(other_backend),
                "message must not name a different backend, got: {message}"
            );
        }
    }

    /// Without the feature, `from_env` must fail the same way as `new`.
    #[cfg(not(feature = "gloo-backend"))]
    #[test]
    fn gloo_from_env_unavailable_without_feature() {
        let err = GlooBackend::from_env().expect_err("default build must err");
        let message = match err {
            FerrotorchError::InvalidArgument { message } => message,
            other => panic!("expected InvalidArgument, got {other:?}"),
        };
        assert!(message.contains("`gloo`"));
    }

    /// The availability probe must report `false` in builds without the
    /// feature; nothing is asserted when the feature is enabled.
    #[test]
    fn is_gloo_available_default_off() {
        if cfg!(feature = "gloo-backend") {
            return;
        }
        assert!(!is_gloo_available());
    }
}