use crate::ibverbs::error::IbvResult;
use crate::ibverbs::protection_domain::ProtectionDomain;
use crate::multi_channel::MultiChannel;
use crate::multi_channel::PeerRemoteMemoryRegion;
use crate::network::barrier::memory::{BarrierMr, PreparedBarrierMr};
use crate::network::barrier::{BarrierError, validate_peer_list};
use std::time::{Duration, Instant};
/// Barrier over a group of peers in which the first rank of the peer list acts
/// as the leader.
///
/// Followers notify the leader and wait for its release. The leader waits for
/// every follower and then notifies them all. Obtain one via
/// `PreparedLinearBarrier::link_remote`.
#[derive(Debug)]
pub struct LinearBarrier {
/// Rank of the local participant; compared against the peer list to decide
/// whether this process is the leader or a follower.
rank: usize,
/// Barrier memory region used to notify peers and poll their epochs.
barrier_mr: BarrierMr,
/// Set once a barrier round fails; afterwards every call to
/// `barrier`/`barrier_unchecked` returns `BarrierError::Poisoned`.
poisoned: bool,
}
/// A linear barrier whose local memory region exists but has not yet been
/// linked to the peers' remote regions.
///
/// Call `link_remote` to turn it into a usable `LinearBarrier`.
#[derive(Debug)]
pub struct PreparedLinearBarrier {
/// Rank of the local participant.
rank: usize,
/// Local barrier memory region, not yet linked to remote peers.
barrier_mr: PreparedBarrierMr,
}
impl PreparedLinearBarrier {
    /// Creates the prepared barrier for `rank` in a group of `world_size`
    /// participants, building its memory region with `PreparedBarrierMr::new`.
    ///
    /// # Errors
    /// Propagates any error returned by `PreparedBarrierMr::new`.
    pub fn new(
        pd: &ProtectionDomain,
        rank: usize,
        world_size: usize,
    ) -> IbvResult<PreparedLinearBarrier> {
        let barrier_mr = PreparedBarrierMr::new(pd, rank, world_size)?;
        Ok(Self { rank, barrier_mr })
    }

    /// Returns the remote-region descriptor of the local barrier memory.
    ///
    /// NOTE(review): presumably this is what other peers pass to their own
    /// `link_remote` after an out-of-band exchange; confirm against callers.
    pub fn remote(&self) -> PeerRemoteMemoryRegion {
        let local_region = &self.barrier_mr;
        local_region.remote()
    }

    /// Consumes the prepared barrier and links its memory region with the
    /// peers' regions in `remote_mrs`.
    ///
    /// The resulting `LinearBarrier` starts out unpoisoned.
    /// NOTE(review): the expected ordering of `remote_mrs` (e.g. indexed by
    /// rank) is defined by `PreparedBarrierMr::link_remote`; confirm there.
    pub fn link_remote(self, remote_mrs: Box<[PeerRemoteMemoryRegion]>) -> LinearBarrier {
        let Self { rank, barrier_mr } = self;
        let linked = barrier_mr.link_remote(remote_mrs);
        LinearBarrier {
            rank,
            barrier_mr: linked,
            poisoned: false,
        }
    }
}
impl LinearBarrier {
    /// Waits until every rank in `peers` has reached the barrier, after
    /// validating the peer list.
    ///
    /// The local rank is located in `peers` with a binary search, so `peers`
    /// is expected to be sorted. Presumably `validate_peer_list` enforces this;
    /// TODO confirm.
    ///
    /// # Errors
    /// - any error produced by `validate_peer_list`;
    /// - `BarrierError::SelfNotInGroup` if this rank is not in `peers`;
    /// - any error from `barrier_unchecked`.
    pub fn barrier(
        &mut self,
        multi_channel: &mut MultiChannel,
        peers: &[usize],
        timeout: Duration,
    ) -> Result<(), BarrierError> {
        validate_peer_list(peers)?;
        if peers.binary_search(&self.rank).is_err() {
            return Err(BarrierError::SelfNotInGroup);
        }
        self.barrier_unchecked(multi_channel, peers, timeout)
    }

    /// Same as `barrier` but skips validation of `peers`; the caller is
    /// responsible for passing a valid group that contains this rank.
    ///
    /// # Errors
    /// Returns `BarrierError::Poisoned` if an earlier round failed. Any
    /// failure of this round is returned and poisons the barrier for all
    /// future calls.
    pub fn barrier_unchecked(
        &mut self,
        multi_channel: &mut MultiChannel,
        peers: &[usize],
        timeout: Duration,
    ) -> Result<(), BarrierError> {
        if self.poisoned {
            return Err(BarrierError::Poisoned);
        }
        let outcome = self.run_barrier(multi_channel, peers, timeout);
        // `poisoned` is known to be false here, so this only ever flips it on
        // failure and leaves it false on success.
        self.poisoned = outcome.is_err();
        outcome
    }

    /// Performs one barrier round.
    ///
    /// The first entry of `peers` is the leader:
    /// - A follower notifies the leader, then waits for the leader's epoch to
    ///   reach the next expected value.
    /// - The leader waits, one follower at a time, for each follower's epoch,
    ///   then notifies all followers at once.
    ///
    /// A single start instant is shared by every poll in this round.
    /// NOTE(review): presumably this makes `timeout` bound the whole round
    /// rather than each poll; confirm in `spin_poll_peer_epoch_expected`.
    fn run_barrier(
        &mut self,
        multi_channel: &mut MultiChannel,
        peers: &[usize],
        timeout: Duration,
    ) -> Result<(), BarrierError> {
        // With fewer than two participants there is nobody to wait for.
        let (leader, followers) = match peers {
            [head, tail @ ..] if !tail.is_empty() => (*head, tail),
            _ => return Ok(()),
        };
        let started = Instant::now();

        if self.rank != leader {
            self.barrier_mr.notify_peer(multi_channel, leader)?;
            self.barrier_mr.increase_peer_expected_epoch(leader);
            self.barrier_mr
                .spin_poll_peer_epoch_expected(leader, started, timeout)?;
            return Ok(());
        }

        for &follower in followers {
            self.barrier_mr.increase_peer_expected_epoch(follower);
            self.barrier_mr
                .spin_poll_peer_epoch_expected(follower, started, timeout)?;
        }
        self.barrier_mr
            .scatter_notify_peers(multi_channel, followers)?;
        Ok(())
    }
}