use crate::client::NexarClient;
use crate::collective::helpers::ceil_log2;
use crate::error::{NexarError, Result};
use crate::protocol::NexarMessage;
use crate::types::Priority;
use std::time::Duration;
/// World size at which `barrier` switches algorithms: world sizes strictly
/// below this value use the two-phase barrier, and world sizes at or above it
/// use the dissemination barrier.
const DISSEMINATION_THRESHOLD: u32 = 5;
/// Synchronizes the calling rank with the other ranks of `client`'s world.
///
/// A world of size 0 or 1 completes immediately without any communication.
/// Otherwise the algorithm is picked by world size: the two-phase barrier for
/// sizes below [`DISSEMINATION_THRESHOLD`], the dissemination barrier for
/// everything else. `timeout` is forwarded unchanged to the chosen algorithm.
///
/// # Errors
///
/// Returns whatever error the selected barrier implementation reports.
pub async fn barrier(client: &NexarClient, timeout: Duration) -> Result<()> {
    let size = client.world_size();
    match size {
        // Nobody else to wait for.
        0 | 1 => Ok(()),
        small if small < DISSEMINATION_THRESHOLD => two_phase_barrier(client, timeout).await,
        _ => dissemination_barrier(client, timeout).await,
    }
}
/// Barrier coordinated by rank 0 in two phases.
///
/// * Every non-zero rank sends `Barrier { epoch, comm_id }` to rank 0 and then
///   waits for a matching `BarrierAck` from rank 0.
/// * Rank 0 first receives a matching `Barrier` from each of the ranks
///   `1..world_size` in ascending order, and only afterwards sends a
///   `BarrierAck` to each of them.
///
/// `timeout` bounds each individual receive; the sends are not wrapped in it.
///
/// # Errors
///
/// Every failure is reported as `NexarError::CollectiveFailed` with
/// `operation: "barrier"` and the rank of the peer involved: a send or receive
/// error, a receive that timed out, or a received message that is not the
/// expected variant with the expected `epoch` and `comm_id`.
async fn two_phase_barrier(client: &NexarClient, timeout: Duration) -> Result<()> {
    let epoch = client.next_barrier_epoch();
    let my_rank = client.rank();
    let world_size = client.world_size();
    let comm_id = client.comm_id();

    // Single place where the barrier failure variant is assembled.
    let fail = |peer, reason: String| NexarError::CollectiveFailed {
        operation: "barrier",
        rank: peer,
        reason,
    };

    if my_rank != 0 {
        // Follower: announce arrival to rank 0, then block until released.
        let arrival = NexarMessage::Barrier { epoch, comm_id };
        if let Err(cause) = client
            .send_message_to(0, &arrival, Priority::Critical)
            .await
        {
            return Err(fail(0, cause.to_string()));
        }

        let reply = match tokio::time::timeout(timeout, client.recv_control_from(0)).await {
            Err(_) => {
                return Err(fail(
                    0,
                    format!(
                        "timed out waiting for BarrierAck(epoch={epoch}) after {}ms",
                        timeout.as_millis()
                    ),
                ));
            }
            Ok(Err(cause)) => return Err(fail(0, cause.to_string())),
            Ok(Ok(message)) => message,
        };

        return match reply {
            NexarMessage::BarrierAck {
                epoch: e,
                comm_id: c,
            } if e == epoch && c == comm_id => Ok(()),
            other => Err(fail(
                0,
                format!("expected BarrierAck(epoch={epoch}, comm_id={comm_id}), got {other:?}"),
            )),
        };
    }

    // Coordinator, phase 1: collect one arrival message from every other rank.
    for peer in 1..world_size {
        let arrival = match tokio::time::timeout(timeout, client.recv_control_from(peer)).await {
            Err(_) => {
                return Err(fail(
                    peer,
                    format!(
                        "timed out waiting for Barrier(epoch={epoch}) after {}ms",
                        timeout.as_millis()
                    ),
                ));
            }
            Ok(Err(cause)) => return Err(fail(peer, cause.to_string())),
            Ok(Ok(message)) => message,
        };

        match arrival {
            NexarMessage::Barrier {
                epoch: e,
                comm_id: c,
            } if e == epoch && c == comm_id => {}
            other => {
                return Err(fail(
                    peer,
                    format!("expected Barrier(epoch={epoch}, comm_id={comm_id}), got {other:?}"),
                ));
            }
        }
    }

    // Coordinator, phase 2: everyone has arrived, release them one by one.
    let release = NexarMessage::BarrierAck { epoch, comm_id };
    for peer in 1..world_size {
        client
            .send_message_to(peer, &release, Priority::Critical)
            .await
            .map_err(|cause| fail(peer, cause.to_string()))?;
    }

    Ok(())
}
/// Barrier without a central coordinator, run in `ceil_log2(world_size)` rounds.
///
/// In round `k` (starting at 0) the stride is `1 << k`. The calling rank sends
/// `Barrier { epoch, comm_id }` to `(rank + stride) % world_size` and, at the
/// same time, waits for a matching `Barrier` from
/// `(rank + world_size - stride) % world_size`. A round is finished once both
/// the send and the receive have succeeded; the first failure of either one
/// aborts the barrier.
///
/// `timeout` bounds the receive of each round; the send is not wrapped in it.
///
/// # Errors
///
/// Every failure is reported as `NexarError::CollectiveFailed` with
/// `operation: "barrier"` and the rank of the peer involved: a send or receive
/// error, a receive that timed out, or a received message that is not a
/// `Barrier` carrying the expected `epoch` and `comm_id`.
async fn dissemination_barrier(client: &NexarClient, timeout: Duration) -> Result<()> {
    let epoch = client.next_barrier_epoch();
    let my_rank = client.rank();
    let world_size = client.world_size();
    let comm_id = client.comm_id();

    // Single place where the barrier failure variant is assembled.
    let fail = |peer, reason: String| NexarError::CollectiveFailed {
        operation: "barrier",
        rank: peer,
        reason,
    };

    for round in 0..ceil_log2(world_size) {
        let stride = 1u32 << round;
        let successor = (my_rank + stride) % world_size;
        // Adding world_size before subtracting keeps the value non-negative.
        let predecessor = (my_rank + world_size - stride) % world_size;
        let token = NexarMessage::Barrier { epoch, comm_id };

        let notify = async {
            client
                .send_message_to(successor, &token, Priority::Critical)
                .await
                .map_err(|cause| fail(successor, cause.to_string()))
        };

        let await_peer = async {
            let incoming =
                match tokio::time::timeout(timeout, client.recv_control_from(predecessor)).await {
                    Err(_) => {
                        return Err(fail(
                            predecessor,
                            format!(
                                "timed out in dissemination round {round} after {}ms",
                                timeout.as_millis()
                            ),
                        ));
                    }
                    Ok(Err(cause)) => return Err(fail(predecessor, cause.to_string())),
                    Ok(Ok(message)) => message,
                };

            match incoming {
                NexarMessage::Barrier {
                    epoch: e,
                    comm_id: c,
                } if e == epoch && c == comm_id => Ok(()),
                other => Err(fail(
                    predecessor,
                    format!("expected Barrier(epoch={epoch}, comm_id={comm_id}), got {other:?}"),
                )),
            }
        };

        // Send and receive are driven concurrently within the round.
        tokio::try_join!(notify, await_peer)?;
    }

    Ok(())
}