use crate::ibverbs::error::IbvResult;
use crate::ibverbs::protection_domain::ProtectionDomain;
use crate::multi_channel::MultiChannel;
use crate::multi_channel::PeerRemoteMemoryRegion;
use crate::network::barrier::memory::{BarrierMr, PreparedBarrierMr};
use crate::network::barrier::{BarrierError, validate_peer_list};
use std::time::{Duration, Instant};
/// Yields, for the node at position `idx` of a group of `len` nodes, one
/// `(notify_idx, wait_idx)` pair per dissemination round.
///
/// Round `k` uses distance `2^k`: the node notifies the position `2^k` ahead of it and
/// waits on the position `2^k` behind it (both modulo `len`). Rounds continue while the
/// distance is smaller than `len`, giving ceil(log2(len)) rounds and none for `len <= 1`.
/// `idx` is expected to be `< len`.
fn round_pairs(idx: usize, len: usize) -> impl Iterator<Item = (usize, usize)> {
    // Every power of two representable in a usize, from 1 up to 2^(BITS-1).
    (0..usize::BITS)
        .map(|round| 1usize << round)
        .take_while(move |&distance| distance < len)
        .map(move |distance| {
            // `distance < len`, so `idx + len - distance` cannot underflow.
            let notify = (idx + distance) % len;
            let wait = (idx + len - distance) % len;
            (notify, wait)
        })
}
/// A barrier over a group of ranks that uses the dissemination schedule from
/// `round_pairs`: in each round this rank notifies one peer and waits on another.
///
/// Produced by `PreparedDisseminationBarrier::link_remote`. A barrier call that fails
/// part-way through (e.g. an error from notifying or polling a peer) poisons the
/// instance; every later call then returns `BarrierError::Poisoned`.
#[derive(Debug)]
pub struct DisseminationBarrier {
    // This process's rank; it is looked up in the peer list on every barrier call.
    rank: usize,
    // Barrier memory region used to notify peers and poll their epochs; presumably
    // already linked to every peer's remote region (TODO confirm in `BarrierMr`).
    barrier_mr: BarrierMr,
    // Set once a barrier run has returned an error; checked at the start of every call.
    // NOTE(review): presumably because per-peer epoch bookkeeping may no longer match
    // the peers' after a partial run — confirm.
    poisoned: bool,
}
/// First stage of a `DisseminationBarrier`: this rank's barrier memory region, not yet
/// linked to the regions of the other ranks.
///
/// `remote` returns this rank's `PeerRemoteMemoryRegion`, which presumably has to be
/// distributed to the other ranks. `link_remote` consumes this value together with the
/// peers' descriptors and yields a usable `DisseminationBarrier`.
#[derive(Debug)]
pub struct PreparedDisseminationBarrier {
    // This process's rank, carried over unchanged into the linked barrier.
    rank: usize,
    // Locally registered barrier region awaiting the peers' remote descriptors.
    barrier_mr: PreparedBarrierMr,
}
impl PreparedDisseminationBarrier {
    /// Creates the barrier memory region for `rank` in a group of `world_size` ranks,
    /// registered against `pd`.
    ///
    /// # Errors
    /// Propagates any error returned by `PreparedBarrierMr::new`.
    pub fn new(
        pd: &ProtectionDomain,
        rank: usize,
        world_size: usize,
    ) -> IbvResult<PreparedDisseminationBarrier> {
        let barrier_mr = PreparedBarrierMr::new(pd, rank, world_size)?;
        Ok(Self { rank, barrier_mr })
    }

    /// Returns the remote descriptor of this rank's barrier region, as produced by
    /// `PreparedBarrierMr::remote`.
    pub fn remote(&self) -> PeerRemoteMemoryRegion {
        let Self { barrier_mr, .. } = self;
        barrier_mr.remote()
    }

    /// Consumes the prepared state, links it with the peers' descriptors in `remote_mrs`,
    /// and returns a ready, un-poisoned `DisseminationBarrier` for the same rank.
    pub fn link_remote(self, remote_mrs: Box<[PeerRemoteMemoryRegion]>) -> DisseminationBarrier {
        let Self { rank, barrier_mr } = self;
        let linked_mr = barrier_mr.link_remote(remote_mrs);
        DisseminationBarrier {
            barrier_mr: linked_mr,
            rank,
            poisoned: false,
        }
    }
}
impl DisseminationBarrier {
    /// Blocks until every rank in `peers` has entered this barrier, or an error occurs.
    ///
    /// `peers` is first checked with `validate_peer_list`. A validation failure is
    /// returned as-is and does not poison the barrier. Everything after that is
    /// [`barrier_unchecked`](Self::barrier_unchecked).
    ///
    /// # Errors
    /// Any error from `validate_peer_list`, plus every error `barrier_unchecked` returns.
    pub fn barrier(
        &mut self,
        multi_channel: &mut MultiChannel,
        peers: &[usize],
        timeout: Duration,
    ) -> Result<(), BarrierError> {
        validate_peer_list(peers)?;
        self.barrier_unchecked(multi_channel, peers, timeout)
    }

    /// Runs the barrier without validating `peers`.
    ///
    /// This rank is located in `peers` with `binary_search`, so `peers` must be sorted in
    /// ascending order. It presumably must also be duplicate-free, as `validate_peer_list`
    /// would enforce (TODO confirm). Groups of fewer than two ranks complete immediately.
    /// Otherwise ceil(log2(peers.len())) notify/wait rounds are run. The same start
    /// instant and `timeout` are passed to every poll, so the timeout presumably bounds
    /// all rounds together rather than each round separately.
    ///
    /// # Errors
    /// - [`BarrierError::Poisoned`] if an earlier call failed part-way through.
    /// - [`BarrierError::SelfNotInGroup`] if this rank is not in `peers`. This is detected
    ///   before any peer is contacted or any epoch is touched, so the barrier stays usable.
    /// - Any error from notifying or polling a peer. Such an error poisons the barrier,
    ///   since this rank may have signalled some peers but not others.
    pub fn barrier_unchecked(
        &mut self,
        multi_channel: &mut MultiChannel,
        peers: &[usize],
        timeout: Duration,
    ) -> Result<(), BarrierError> {
        if self.poisoned {
            return Err(BarrierError::Poisoned);
        }
        if peers.len() < 2 {
            return Ok(());
        }
        // Locate ourselves before touching any epoch state: a peer list without this rank
        // is a caller error that leaves the barrier consistent, so it must not poison it.
        let idx = peers
            .binary_search(&self.rank)
            .map_err(|_| BarrierError::SelfNotInGroup)?;
        let result = self.run_rounds(multi_channel, peers, idx, timeout);
        if result.is_err() {
            self.poisoned = true;
        }
        result
    }

    /// Executes every dissemination round for the rank at position `idx` of `peers`
    /// (`peers.len() >= 2`), stopping at the first error.
    fn run_rounds(
        &mut self,
        multi_channel: &mut MultiChannel,
        peers: &[usize],
        idx: usize,
        timeout: Duration,
    ) -> Result<(), BarrierError> {
        let start_time = Instant::now();
        for (right_idx, left_idx) in round_pairs(idx, peers.len()) {
            let right_rank = peers[right_idx];
            let left_rank = peers[left_idx];
            // Tell the peer 2^k positions ahead that we have reached this round...
            self.barrier_mr.notify_peer(multi_channel, right_rank)?;
            // ...then wait until the peer 2^k positions behind has signalled us.
            self.barrier_mr.increase_peer_expected_epoch(left_rank);
            self.barrier_mr
                .spin_poll_peer_epoch_expected(left_rank, start_time, timeout)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the dissemination round schedule produced by `round_pairs`.
    use super::*;

    #[test]
    fn single_node_no_rounds() {
        let rounds: Vec<_> = round_pairs(0, 1).collect();
        assert!(rounds.is_empty());
    }

    #[test]
    fn two_nodes_one_round() {
        let rounds: Vec<_> = round_pairs(0, 2).collect();
        assert_eq!(rounds.len(), 1);
    }

    #[test]
    fn round_count_is_ceil_log2() {
        let cases = [
            (2, 1),
            (3, 2),
            (4, 2),
            (5, 3),
            (7, 3),
            (8, 3),
            (9, 4),
            (16, 4),
        ];
        for (len, expected_rounds) in cases {
            let count = round_pairs(0, len).count();
            assert_eq!(
                count, expected_rounds,
                "expected {expected_rounds} rounds for {len} nodes, got {count}"
            );
        }
    }

    #[test]
    fn two_nodes_notify_each_other() {
        let pairs_0: Vec<_> = round_pairs(0, 2).collect();
        assert_eq!(pairs_0, vec![(1, 1)]);
        let pairs_1: Vec<_> = round_pairs(1, 2).collect();
        assert_eq!(pairs_1, vec![(0, 0)]);
    }

    #[test]
    fn three_nodes_round_pairs() {
        let pairs: Vec<_> = round_pairs(0, 3).collect();
        assert_eq!(pairs, vec![(1, 2), (2, 1)]);
    }

    /// Every notification sent in round `r` is awaited by its target in the same round.
    #[test]
    fn notify_and_wait_are_symmetric() {
        for len in 2..=16 {
            for idx in 0..len {
                for (round, (notify_target, _)) in round_pairs(idx, len).enumerate() {
                    let target_pairs: Vec<_> = round_pairs(notify_target, len).collect();
                    let (_, wait_source) = target_pairs[round];
                    assert_eq!(
                        wait_source, idx,
                        "broken symmetry: node {idx} notifies {notify_target} in round {round}, \
                         but {notify_target} waits for {wait_source} (expected {idx}) in len={len}"
                    );
                }
            }
        }
    }

    /// Weak connectivity check: ignores round order (see the propagation test below).
    #[test]
    fn all_pairs_communicate_transitively() {
        for len in 2..=16 {
            let mut connected = vec![vec![false; len]; len];
            for i in 0..len {
                connected[i][i] = true;
                for (notify_target, wait_source) in round_pairs(i, len) {
                    connected[i][notify_target] = true;
                    connected[i][wait_source] = true;
                }
            }
            for k in 0..len {
                for i in 0..len {
                    for j in 0..len {
                        if connected[i][k] && connected[k][j] {
                            connected[i][j] = true;
                        }
                    }
                }
            }
            for i in 0..len {
                for j in 0..len {
                    assert!(
                        connected[i][j],
                        "nodes {i} and {j} not transitively connected in len={len}"
                    );
                }
            }
        }
    }

    /// Simulates the message flow of one barrier and checks the property that makes it a
    /// barrier: when a node finishes its last round it has (transitively) heard from every
    /// node. Unlike the transitive-closure test above, this respects round order. A
    /// notification sent in round `r` only carries what its sender knew after round `r-1`.
    #[test]
    fn every_node_hears_from_all_before_exiting() {
        for len in 1..=33 {
            let schedules: Vec<Vec<(usize, usize)>> =
                (0..len).map(|i| round_pairs(i, len).collect()).collect();
            let rounds = schedules.first().map_or(0, Vec::len);
            // heard[i][j]: node i knows that node j has entered the barrier.
            let mut heard: Vec<Vec<bool>> =
                (0..len).map(|i| (0..len).map(|j| i == j).collect()).collect();
            for round in 0..rounds {
                let before = heard.clone();
                for (node, schedule) in schedules.iter().enumerate() {
                    let (_, wait_source) = schedule[round];
                    for (mine, &theirs) in heard[node].iter_mut().zip(&before[wait_source]) {
                        *mine |= theirs;
                    }
                }
            }
            for (node, row) in heard.iter().enumerate() {
                if let Some(missing) = row.iter().position(|&h| !h) {
                    panic!("node {node} exits without hearing from {missing} in len={len}");
                }
            }
        }
    }

    #[test]
    fn indices_within_bounds() {
        for len in 1..=32 {
            for idx in 0..len {
                for (right, left) in round_pairs(idx, len) {
                    assert!(right < len, "right={right} out of bounds for len={len}");
                    assert!(left < len, "left={left} out of bounds for len={len}");
                }
            }
        }
    }

    #[test]
    fn never_notifies_self() {
        for len in 2..=16 {
            for idx in 0..len {
                for (right, left) in round_pairs(idx, len) {
                    assert_ne!(right, idx, "node {idx} notifies itself in len={len}");
                    assert_ne!(left, idx, "node {idx} waits for itself in len={len}");
                }
            }
        }
    }
}