use std::time::Duration;
use ferrotorch_core::FerrotorchResult;
use crate::backend::{Backend, TcpBackend};
use crate::nccl_backend::NcclBackend;
use crate::nccl_sys::NcclUniqueId;
/// A backend that owns one [`TcpBackend`] and one [`NcclBackend`] side by side.
///
/// Both members are created together in `HybridBackend::new` from the same
/// `rank` and `world_size`. The `Backend` trait implementation in this file
/// forwards every trait method to the TCP member; the NCCL member is reached
/// through its accessor.
pub struct HybridBackend {
/// TCP backend; target of every `Backend` trait method on this type.
tcp: TcpBackend,
/// NCCL backend; exposed to callers by reference and used for `synchronize`.
nccl: NcclBackend,
}
impl HybridBackend {
    /// Builds a hybrid backend by constructing the TCP backend and then the
    /// NCCL backend, in that order.
    ///
    /// `rank` and `world_size` are passed unchanged to both constructors.
    /// `addr` is passed only to [`TcpBackend::new`], and `unique_id` only to
    /// [`NcclBackend::new`].
    ///
    /// # Errors
    ///
    /// Propagates the error of whichever constructor fails first. If the TCP
    /// constructor fails, the NCCL constructor is never called.
    pub fn new(
        rank: usize,
        world_size: usize,
        addr: &str,
        unique_id: NcclUniqueId,
    ) -> FerrotorchResult<Self> {
        // Field initialisers run in the order written, so TCP setup still
        // happens before NCCL setup.
        Ok(Self {
            tcp: TcpBackend::new(rank, world_size, addr)?,
            nccl: NcclBackend::new(rank, world_size, unique_id)?,
        })
    }

    /// Borrows the owned NCCL backend.
    pub fn nccl(&self) -> &NcclBackend {
        &self.nccl
    }

    /// Borrows the owned TCP backend.
    pub fn tcp(&self) -> &TcpBackend {
        &self.tcp
    }

    /// Calls `synchronize` on the NCCL backend and returns its result.
    pub fn synchronize_nccl(&self) -> FerrotorchResult<()> {
        self.nccl().synchronize()
    }
}
/// `Backend` implementation that forwards every method to the TCP backend.
///
/// The NCCL backend takes no part in any of these calls.
impl Backend for HybridBackend {
    /// Returns the rank reported by the TCP backend.
    fn rank(&self) -> usize {
        let transport = &self.tcp;
        transport.rank()
    }

    /// Returns the world size reported by the TCP backend.
    fn world_size(&self) -> usize {
        let transport = &self.tcp;
        transport.world_size()
    }

    /// Forwards `data` and `dst_rank` to the TCP backend's `send`.
    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
        let transport = &self.tcp;
        transport.send(data, dst_rank)
    }

    /// Forwards the `dst` buffer and `src_rank` to the TCP backend's `recv`.
    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
        let transport = &self.tcp;
        transport.recv(dst, src_rank)
    }

    /// Forwards `dst`, `src_rank` and `timeout` to the TCP backend's
    /// `recv_timeout`.
    fn recv_timeout(
        &self,
        dst: &mut [u8],
        src_rank: usize,
        timeout: Duration,
    ) -> FerrotorchResult<()> {
        let transport = &self.tcp;
        transport.recv_timeout(dst, src_rank, timeout)
    }

    /// Runs the TCP backend's `barrier` and returns its result.
    fn barrier(&self) -> FerrotorchResult<()> {
        let transport = &self.tcp;
        transport.barrier()
    }
}