use crate::cluster::HealthMonitor;
use crate::cluster::sparse::RoutingTable;
use crate::config::NexarConfig;
use crate::device::DeviceAdapter;
use crate::error::{NexarError, Result};
use crate::protocol::NexarMessage;
use crate::rpc::RpcDispatcher;
use crate::rpc::registry::{RpcHandler, RpcRegistry};
use crate::transport::PeerConnection;
use crate::transport::buffer_pool::{BufferPool, PoolProfile, PooledBuf};
use crate::transport::relay::RelayDeliveries;
use crate::transport::router::PeerRouter;
use crate::types::{Priority, Rank};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::{Mutex, RwLock, watch};
/// Map from a `(Rank, u64)` key to a shared, mutex-guarded `mpsc` receiver of
/// [`PooledBuf`]s. Used as the payload type of `NexarClient::tagged_receivers`.
// NOTE(review): the `u64` key component is presumably a message/collective tag — confirm.
type TaggedReceiverMap = HashMap<(Rank, u64), Arc<Mutex<tokio::sync::mpsc::Receiver<PooledBuf>>>>;
/// Tuple returned by `NexarClient::spawn_monitor`, in order: the shared failure
/// `watch` sender, the matching receiver, and the join handle returned by
/// `HealthMonitor::start_monitoring`.
type MonitorParts = (
Arc<watch::Sender<Vec<Rank>>>,
watch::Receiver<Vec<Rank>>,
tokio::task::JoinHandle<()>,
);
/// Selects where the client's raw receives come from.
///
/// `NexarClient::build` always starts a client with [`RawRecvSource::Router`].
pub(super) enum RawRecvSource {
/// No receivers are stored in this variant.
// NOTE(review): presumably raw receives are served by the per-peer routers — confirm.
Router,
/// One mutex-guarded `mpsc` receiver of [`PooledBuf`]s per rank.
// NOTE(review): presumably used by derived/split communicators — confirm against the code that builds it.
Comm(HashMap<Rank, Mutex<tokio::sync::mpsc::Receiver<PooledBuf>>>),
}
/// Client handle for a single rank.
///
/// Holds the peer connections and their routers, the device adapter, the
/// buffer pool, the RPC registry, several monotonically increasing counters,
/// and the failure-monitoring channel.
///
/// Fields whose names start with `_` are not read by the code in this file.
// NOTE(review): the `_`-prefixed fields are presumably held only to keep tasks/endpoints alive — confirm.
pub struct NexarClient {
/// This client's own rank; stamped as `src_rank` on outgoing data messages.
pub(super) rank: Rank,
/// Number of ranks; `send`/`recv` reject any rank `>= world_size`.
pub(super) world_size: u32,
/// Identifier returned by `comm_id()`; `build` initialises it to `0`.
pub(super) comm_id: u64,
/// Direct connections to peers, keyed by peer rank.
pub(crate) peers: HashMap<Rank, Arc<PeerConnection>>,
/// One router per peer, created by `PeerRouter::spawn` in `build`.
pub(super) routers: HashMap<Rank, PeerRouter>,
/// Source of raw receives; `build` sets this to `RawRecvSource::Router`.
pub(super) raw_recv: RawRecvSource,
/// Join handles returned by `PeerRouter::spawn`, one per peer.
pub(super) _router_handles: Vec<tokio::task::JoinHandle<Result<()>>>,
/// Device adapter used to stage outgoing data and to place received data.
pub(super) adapter: Arc<dyn DeviceAdapter>,
/// Buffer pool that was also handed to every peer router in `build`.
pub(super) _pool: Arc<BufferPool>,
/// Counter behind `next_barrier_epoch()`; starts at `0`.
pub(super) barrier_epoch: AtomicU64,
/// RPC handler registry, written by `register_rpc` and shared with `RpcDispatcher`.
pub(super) rpc_registry: Arc<RwLock<RpcRegistry>>,
/// Counter starting at `0`; its users are outside this file.
// NOTE(review): presumably generates RPC request ids — confirm.
pub(super) rpc_req_id: AtomicU64,
/// Counter starting at `0`; its users are outside this file.
pub(super) split_generation: AtomicU64,
/// Rank translation table consulted by `resolve_rank`; empty after `build`,
/// so every rank maps to itself.
pub(super) rank_map: HashMap<Rank, Rank>,
/// Counter behind `next_collective_tag()`; starts at `1`.
pub(super) collective_tag: AtomicU64,
/// Tagged receivers keyed by `(Rank, u64)`; empty after `build`.
pub(super) tagged_receivers: Mutex<TaggedReceiverMap>,
/// Shared configuration, exposed through `config()`.
pub(crate) config: Arc<NexarConfig>,
/// Sender half of the failure channel; a clone is given to the health monitor.
pub(super) failure_tx: Arc<watch::Sender<Vec<Rank>>>,
/// Receiver half of the failure channel; cloned by `failure_watch()`.
pub(super) failure_rx: watch::Receiver<Vec<Rank>>,
/// Join handle returned by `HealthMonitor::start_monitoring`.
pub(super) _monitor_handle: Option<tokio::task::JoinHandle<()>>,
/// Routing table; `None` after `build`. `is_sparse()` reports whether it is set.
pub(crate) routing_table: Option<Arc<RoutingTable>>,
/// Relay deliveries; `None` after `build`. When set, `recv` from a rank without
/// a direct connection goes through `recv_control_from`.
pub(crate) relay_deliveries: Option<Arc<RelayDeliveries>>,
/// Task handles; empty after `build`.
pub(super) _relay_handles: Vec<tokio::task::JoinHandle<()>>,
/// QUIC endpoints; empty after `build`.
pub(crate) _endpoints: Vec<quinn::Endpoint>,
}
impl NexarClient {
/// Creates a client using the [`PoolProfile::Training`] buffer-pool profile
/// and a configuration obtained from [`NexarConfig::from_env`].
///
/// Thin wrapper around [`Self::new_with_config`].
pub fn new(
    rank: Rank,
    world_size: u32,
    peers: HashMap<Rank, PeerConnection>,
    adapter: Arc<dyn DeviceAdapter>,
) -> Self {
    let config = NexarConfig::from_env();
    Self::new_with_config(rank, world_size, peers, adapter, PoolProfile::Training, config)
}
/// Creates a client with the given buffer-pool `profile` and a configuration
/// obtained from [`NexarConfig::from_env`].
///
/// Thin wrapper around [`Self::new_with_config`].
pub fn new_with_profile(
    rank: Rank,
    world_size: u32,
    peers: HashMap<Rank, PeerConnection>,
    adapter: Arc<dyn DeviceAdapter>,
    profile: PoolProfile,
) -> Self {
    let config = NexarConfig::from_env();
    Self::new_with_config(rank, world_size, peers, adapter, profile, config)
}
/// Creates a client with an explicit `config`, allocating a fresh buffer pool
/// via [`BufferPool::with_profile`] for the requested `profile`.
pub fn new_with_config(
    rank: Rank,
    world_size: u32,
    peers: HashMap<Rank, PeerConnection>,
    adapter: Arc<dyn DeviceAdapter>,
    profile: PoolProfile,
    config: NexarConfig,
) -> Self {
    Self::build(
        rank,
        world_size,
        peers,
        adapter,
        BufferPool::with_profile(profile),
        config,
    )
}
/// Creates a client that uses a caller-supplied buffer `pool` and a
/// configuration obtained from [`NexarConfig::from_env`].
///
/// Thin wrapper around [`Self::new_with_pool_and_config`].
pub fn new_with_pool(
    rank: Rank,
    world_size: u32,
    peers: HashMap<Rank, PeerConnection>,
    adapter: Arc<dyn DeviceAdapter>,
    pool: Arc<BufferPool>,
) -> Self {
    let config = NexarConfig::from_env();
    Self::new_with_pool_and_config(rank, world_size, peers, adapter, pool, config)
}
/// Creates a client from a caller-supplied buffer `pool` and an explicit
/// `config`.
///
/// This is the most explicit public constructor: nothing is defaulted, every
/// argument is passed straight to the private `build` routine.
pub fn new_with_pool_and_config(
    rank: Rank,
    world_size: u32,
    peers: HashMap<Rank, PeerConnection>,
    adapter: Arc<dyn DeviceAdapter>,
    pool: Arc<BufferPool>,
    config: NexarConfig,
) -> Self {
    // The caller's pool is used as-is; no pool is constructed here.
    Self::build(rank, world_size, peers, adapter, pool, config)
}
/// Shared construction routine behind every public `new*` constructor.
///
/// For each entry of `peers` this spawns a [`PeerRouter`] on a clone of the
/// peer's connection (sharing `pool`), wraps the [`PeerConnection`] in an
/// `Arc`, and records the router together with its join handle. It then
/// starts health monitoring through [`Self::spawn_monitor`] and assembles the
/// client with:
///
/// * `comm_id` = 0, an empty `rank_map`, and `raw_recv` = `RawRecvSource::Router`;
/// * `barrier_epoch`, `rpc_req_id` and `split_generation` starting at 0, and
///   `collective_tag` starting at 1;
/// * an empty RPC registry and an empty tagged-receiver map;
/// * no routing table, no relay deliveries, no relay handles, no endpoints.
// NOTE(review): the returned handles are tokio `JoinHandle`s, so this presumably
// has to run inside a Tokio runtime — confirm against `PeerRouter::spawn`.
fn build(
    rank: Rank,
    world_size: u32,
    peers: HashMap<Rank, PeerConnection>,
    adapter: Arc<dyn DeviceAdapter>,
    pool: Arc<BufferPool>,
    config: NexarConfig,
) -> Self {
    // Each peer contributes exactly one entry to each collection below, so
    // size them once up front instead of growing them while iterating.
    let peer_count = peers.len();
    let mut peer_arcs: HashMap<Rank, Arc<PeerConnection>> = HashMap::with_capacity(peer_count);
    let mut routers: HashMap<Rank, PeerRouter> = HashMap::with_capacity(peer_count);
    let mut handles = Vec::with_capacity(peer_count);

    for (peer_rank, peer_conn) in peers {
        // The router works on its own clone of the connection; the
        // `PeerConnection` itself is kept (behind an `Arc`) in `peers`.
        let conn_clone = peer_conn.conn.clone();
        let (router, handle) = PeerRouter::spawn(peer_rank, conn_clone, Arc::clone(&pool));
        peer_arcs.insert(peer_rank, Arc::new(peer_conn));
        routers.insert(peer_rank, router);
        handles.push(handle);
    }

    let (failure_tx, failure_rx, monitor_handle) = Self::spawn_monitor(&config, &peer_arcs);

    Self {
        rank,
        world_size,
        comm_id: 0,
        peers: peer_arcs,
        routers,
        raw_recv: RawRecvSource::Router,
        _router_handles: handles,
        adapter,
        _pool: pool,
        barrier_epoch: AtomicU64::new(0),
        rpc_registry: Arc::new(RwLock::new(RpcRegistry::new())),
        rpc_req_id: AtomicU64::new(0),
        split_generation: AtomicU64::new(0),
        rank_map: HashMap::new(),
        collective_tag: AtomicU64::new(1),
        tagged_receivers: Mutex::new(HashMap::new()),
        config: Arc::new(config),
        failure_tx,
        failure_rx,
        _monitor_handle: Some(monitor_handle),
        routing_table: None,
        relay_deliveries: None,
        _relay_handles: Vec::new(),
        _endpoints: Vec::new(),
    }
}
/// Creates the failure `watch` channel and starts health monitoring for
/// `peers`.
///
/// The channel starts out holding an empty rank list. The monitor is built
/// from `config.heartbeat_interval` and `config.heartbeat_timeout` and is
/// given a clone of every peer connection plus a clone of the sender.
///
/// Returns `(failure_tx, failure_rx, handle)` where `handle` is whatever
/// `HealthMonitor::start_monitoring` returned.
fn spawn_monitor(
    config: &NexarConfig,
    peers: &HashMap<Rank, Arc<PeerConnection>>,
) -> MonitorParts {
    let (tx, failure_rx) = watch::channel(Vec::new());
    let failure_tx = Arc::new(tx);

    let monitor =
        HealthMonitor::with_timeout(config.heartbeat_interval, config.heartbeat_timeout);

    // Snapshot the peer set as owned `(rank, connection)` pairs for the monitor.
    let mut monitor_peers = Vec::with_capacity(peers.len());
    for (peer_rank, conn) in peers {
        monitor_peers.push((*peer_rank, Arc::clone(conn)));
    }

    let handle = monitor.start_monitoring(monitor_peers, Arc::clone(&failure_tx));
    (failure_tx, failure_rx, handle)
}
/// Advances the barrier epoch counter by one and returns the value it held
/// *before* the increment.
///
/// Uses `Ordering::Relaxed`: the increment is atomic but establishes no
/// ordering with other memory operations.
pub(crate) fn next_barrier_epoch(&self) -> u64 {
    let counter = &self.barrier_epoch;
    counter.fetch_add(1, Ordering::Relaxed)
}
/// Registers `handler` under `fn_id` in the client's RPC registry.
///
/// Waits for the registry's write lock, so it suspends while any other
/// reader or writer holds the lock.
pub async fn register_rpc(&self, fn_id: u16, handler: RpcHandler) {
    // The write guard is a temporary and is released at the end of this statement.
    self.rpc_registry.write().await.register(fn_id, handler);
}
/// Builds a new [`RpcDispatcher`] that shares this client's RPC registry.
pub fn rpc_dispatcher(&self) -> RpcDispatcher {
    let registry = Arc::clone(&self.rpc_registry);
    RpcDispatcher::new(registry)
}
pub fn rank(&self) -> Rank {
self.rank
}
pub fn world_size(&self) -> u32 {
self.world_size
}
pub fn comm_id(&self) -> u64 {
self.comm_id
}
/// Borrows the device adapter this client uses for staging and receiving data.
pub fn adapter(&self) -> &dyn DeviceAdapter {
    &*self.adapter
}
/// Borrows the configuration this client was created with.
pub fn config(&self) -> &NexarConfig {
    self.config.as_ref()
}
/// Looks up the direct connection to `rank`.
///
/// # Errors
///
/// Returns [`NexarError::UnknownPeer`] when no connection is stored for `rank`.
pub fn peer(&self, rank: Rank) -> Result<&Arc<PeerConnection>> {
    match self.peers.get(&rank) {
        Some(conn) => Ok(conn),
        None => Err(NexarError::UnknownPeer { rank }),
    }
}
/// Translates `rank` through `rank_map`.
///
/// Ranks without an entry in the map are returned unchanged.
pub(super) fn resolve_rank(&self, rank: Rank) -> Rank {
    match self.rank_map.get(&rank) {
        Some(&mapped) => mapped,
        None => rank,
    }
}
/// Sends `size` bytes located at `data_ptr` to rank `dest` as a tagged data
/// message.
///
/// The data is first staged through the device adapter's `stage_for_send`,
/// wrapped in a [`NexarMessage::Data`] carrying `tag` and this client's rank,
/// and then sent with [`Priority::Bulk`].
///
/// # Safety
///
/// `data_ptr` and `size` are forwarded unchanged to
/// `DeviceAdapter::stage_for_send`; the caller must satisfy that method's
/// safety contract for them.
///
/// # Errors
///
/// * [`NexarError::InvalidRank`] if `dest >= world_size`.
/// * Any error returned by staging or by the underlying send.
pub async unsafe fn send(
    &self,
    data_ptr: u64,
    size: usize,
    dest: Rank,
    tag: u32,
) -> Result<()> {
    let world_size = self.world_size;
    if dest >= world_size {
        return Err(NexarError::InvalidRank {
            rank: dest,
            world_size,
        });
    }

    // SAFETY: the caller of this `unsafe fn` vouches for `data_ptr`/`size`
    // as required by the adapter (see `# Safety`).
    let payload = unsafe { self.adapter.stage_for_send(data_ptr, size)? };

    let msg = NexarMessage::Data {
        tag,
        src_rank: self.rank,
        payload,
    };
    self.send_message_to(dest, &msg, Priority::Bulk).await
}
/// Receives one data message from rank `src` and writes its payload to
/// `buf_ptr` through the device adapter.
///
/// When there is no direct connection to `src` and relay deliveries are
/// configured, the message is taken from `recv_control_from`; otherwise it is
/// taken from `recv_data_message`. The message must be a
/// [`NexarMessage::Data`] whose tag equals `tag` and whose payload is exactly
/// `size` bytes long. On any validation failure the already-received message
/// is dropped; this function does not re-queue it.
///
/// # Safety
///
/// `buf_ptr` is forwarded unchanged to `DeviceAdapter::receive_to_device`
/// together with a payload of `size` bytes; the caller must satisfy that
/// method's safety contract.
///
/// # Errors
///
/// * [`NexarError::InvalidRank`] if `src >= world_size`.
/// * [`NexarError::DecodeFailed`] if the message is not a data message or its
///   tag differs from `tag`.
/// * [`NexarError::BufferSizeMismatch`] if the payload length differs from `size`.
/// * Any error returned by the receive path or by the adapter.
pub async unsafe fn recv(&self, buf_ptr: u64, size: usize, src: Rank, tag: u32) -> Result<()> {
    if src >= self.world_size {
        return Err(NexarError::InvalidRank {
            rank: src,
            world_size: self.world_size,
        });
    }

    // Without a direct connection, data only arrives via the relay path
    // (when one is configured).
    let via_relay = !self.has_direct_peer(src) && self.relay_deliveries.is_some();
    let msg = if via_relay {
        self.recv_control_from(src).await?
    } else {
        self.recv_data_message(src).await?
    };

    let (recv_tag, payload) = match msg {
        NexarMessage::Data {
            tag: recv_tag,
            payload,
            ..
        } => (recv_tag, payload),
        other => {
            return Err(NexarError::DecodeFailed(format!(
                "expected Data message, got {other:?}"
            )));
        }
    };

    if recv_tag != tag {
        return Err(NexarError::DecodeFailed(format!(
            "tag mismatch: expected {tag}, got {recv_tag}"
        )));
    }

    let actual = payload.len();
    if actual != size {
        return Err(NexarError::BufferSizeMismatch {
            expected: size,
            actual,
        });
    }

    // SAFETY: the payload length was just checked against `size`; validity of
    // `buf_ptr` for that many bytes is the caller's obligation (see `# Safety`).
    unsafe { self.adapter.receive_to_device(&payload, buf_ptr)? };
    Ok(())
}
/// Returns a new receiver for the failure channel (a clone of the client's
/// own receiver), carrying a `Vec<Rank>`.
pub fn failure_watch(&self) -> watch::Receiver<Vec<Rank>> {
    let rx = &self.failure_rx;
    rx.clone()
}
/// Advances the collective tag counter by one and returns the value it held
/// *before* the increment.
///
/// Uses `Ordering::Relaxed`: the increment is atomic but establishes no
/// ordering with other memory operations.
pub(crate) fn next_collective_tag(&self) -> u64 {
    let counter = &self.collective_tag;
    counter.fetch_add(1, Ordering::Relaxed)
}
/// Closes the connection to every direct peer with error code `0` and the
/// reason bytes `b"closed"`.
///
/// Only the connections are closed here; the task handles stored on the
/// client are not touched by this method.
pub fn close(&self) {
    self.peers.values().for_each(|peer| {
        peer.conn.close(0u32.into(), b"closed");
    });
}
/// Reports whether a direct connection to `rank` is stored on this client.
pub fn has_direct_peer(&self, rank: Rank) -> bool {
    let direct_peers = &self.peers;
    direct_peers.contains_key(&rank)
}
/// Reports whether a routing table is set on this client.
pub fn is_sparse(&self) -> bool {
    let table = &self.routing_table;
    table.is_some()
}
}