use std::time::Duration;
use ferrotorch_core::FerrotorchResult;
use crate::backend::Backend;
use crate::error::DistributedError;
/// Reports whether MPI support was compiled into this crate.
///
/// Returns `true` exactly when the crate was built with the `mpi-backend`
/// Cargo feature. This is a compile-time check only; it does not probe for
/// an MPI runtime on the host. Being a `const fn`, it can also be used in
/// constant contexts.
#[must_use]
pub const fn is_mpi_available() -> bool {
    cfg!(feature = "mpi-backend")
}
/// Communication backend placeholder for MPI (`"mpi"`).
///
/// It currently stores only this process's rank and the world size. Every
/// communication primitive in its `Backend` implementation returns
/// `DistributedError::BackendUnavailable { backend: "mpi" }`.
#[derive(Debug)]
pub struct MpiBackend {
    /// Rank of this process, stored exactly as passed to `MpiBackend::new`.
    rank: usize,
    /// Number of processes in the group, stored exactly as passed to `MpiBackend::new`.
    world_size: usize,
}
impl MpiBackend {
    /// Creates a handle for process `rank` in a group of `world_size` processes.
    ///
    /// Both values are stored as given. No check is made that
    /// `rank < world_size`.
    ///
    /// # Errors
    ///
    /// Returns `DistributedError::BackendUnavailable { backend: "mpi" }`,
    /// converted via `Into` to the `FerrotorchResult` error type, when the
    /// crate was compiled without the `mpi-backend` feature.
    pub fn new(rank: usize, world_size: usize) -> FerrotorchResult<Self> {
        if is_mpi_available() {
            Ok(Self { rank, world_size })
        } else {
            Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
        }
    }
}
/// Shared failure for every `MpiBackend` communication primitive.
///
/// No MPI transport is wired up, so each operation reports the `"mpi"`
/// backend as unavailable. The error is converted via `Into` to the
/// `FerrotorchResult` error type.
fn mpi_unavailable<T>() -> FerrotorchResult<T> {
    Err(DistributedError::BackendUnavailable { backend: "mpi" }.into())
}

impl Backend for MpiBackend {
    /// Rank this handle was constructed with.
    fn rank(&self) -> usize {
        self.rank
    }

    /// Group size this handle was constructed with.
    fn world_size(&self) -> usize {
        self.world_size
    }

    /// Not implemented; always returns `BackendUnavailable`.
    fn send(&self, _data: &[u8], _dst_rank: usize) -> FerrotorchResult<()> {
        mpi_unavailable()
    }

    /// Not implemented; always returns `BackendUnavailable` and leaves `_dst` untouched.
    fn recv(&self, _dst: &mut [u8], _src_rank: usize) -> FerrotorchResult<()> {
        mpi_unavailable()
    }

    /// Not implemented; returns `BackendUnavailable` immediately, without waiting for the timeout.
    fn recv_timeout(
        &self,
        _dst: &mut [u8],
        _src_rank: usize,
        _timeout: Duration,
    ) -> FerrotorchResult<()> {
        mpi_unavailable()
    }

    /// Not implemented; always returns `BackendUnavailable`.
    fn barrier(&self) -> FerrotorchResult<()> {
        mpi_unavailable()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Without the `mpi-backend` feature, construction must fail up front.
    ///
    /// This is gated with `#[cfg]` rather than a runtime `if`, so it never
    /// "passes" without asserting anything.
    #[test]
    #[cfg(not(feature = "mpi-backend"))]
    fn mpi_unavailable_without_feature() {
        assert!(!is_mpi_available());
        assert!(MpiBackend::new(0, 2).is_err());
    }

    /// With the `mpi-backend` feature, construction succeeds and the handle
    /// reports the rank and world size it was built with.
    #[test]
    #[cfg(feature = "mpi-backend")]
    fn mpi_constructs_with_feature() {
        assert!(is_mpi_available());
        let backend = match MpiBackend::new(1, 4) {
            Ok(backend) => backend,
            Err(_) => panic!("MpiBackend::new must succeed when mpi-backend is enabled"),
        };
        assert_eq!(backend.rank(), 1);
        assert_eq!(backend.world_size(), 4);
    }
}