use crate::client::NexarClient;
use crate::cluster::seed::PendingJoin;
use crate::config::NexarConfig;
use crate::error::{NexarError, Result};
use crate::protocol::NexarMessage;
use crate::transport::tls::make_client_config_mtls;
use crate::transport::{PeerConnection, TransportListener};
use crate::types::{Priority, Rank};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex as StdMutex};
use tokio::sync::{Mutex, broadcast};
/// Tunables for elastic (dynamic) cluster membership.
#[derive(Debug, Clone)]
pub struct ElasticConfig {
    /// Master switch for elastic resize support.
    pub enabled: bool,
    /// Presumably the lower bound on cluster size after nodes leave —
    /// NOTE(review): not consulted anywhere in this file; confirm where
    /// it is enforced.
    pub min_world_size: u32,
    /// Presumably the upper bound on cluster size (`Default` leaves it at
    /// 0, which likely means "unbounded") — NOTE(review): also not
    /// consulted in this file; confirm.
    pub max_world_size: u32,
}
impl Default for ElasticConfig {
fn default() -> Self {
Self {
enabled: true,
min_world_size: 1,
max_world_size: 0,
}
}
}
/// Snapshot of one completed membership change, broadcast to subscribers
/// after `elastic_checkpoint` applies a resize.
#[derive(Debug, Clone)]
pub struct ElasticEvent {
    /// World size before the resize.
    pub old_world_size: u32,
    /// World size after the resize.
    pub new_world_size: u32,
    /// This process's rank for the event (`apply_resize` passes the
    /// caller's current rank unchanged).
    pub new_rank: Rank,
    /// Ranks that joined in this checkpoint round.
    pub joined: Vec<Rank>,
    /// Ranks that left in this checkpoint round.
    pub left: Vec<Rank>,
}
/// Bundle of elastic managers plus the seed address they were created
/// against.
/// NOTE(review): constructed outside this file — the exact relationship
/// between `managers` entries and cluster ranks is inferred from the
/// names; confirm against the bootstrap code.
pub struct ElasticBootstrap {
    pub managers: Vec<ElasticManager>,
    pub seed_addr: SocketAddr,
}
/// Coordinates elastic membership changes for one process in the mesh.
///
/// Shared-state fields are wrapped in `Arc` so clones of a manager observe
/// the same queues and client (see the manual `Clone` impl below).
pub struct ElasticManager {
    /// The collective client; replaced wholesale when the mesh is rebuilt
    /// in `apply_resize`.
    client: Arc<Mutex<NexarClient>>,
    /// Joins queued for the next checkpoint (filled by `add_nodes` or the
    /// injected queue passed to `new`); drained by `elastic_checkpoint`.
    pending_joins: Arc<StdMutex<Vec<PendingJoin>>>,
    /// Ranks queued by `remove_node`; drained by `elastic_checkpoint`.
    pending_leaves: Arc<StdMutex<Vec<Rank>>>,
    /// Monotonic id for checkpoint rounds; incremented once per non-empty
    /// round on every rank.
    checkpoint_epoch: AtomicU64,
    /// Broadcast channel carrying completed `ElasticEvent`s.
    event_tx: broadcast::Sender<ElasticEvent>,
    /// Elastic tuning knobs. NOTE(review): `min_world_size` /
    /// `max_world_size` are never read in this file — confirm where they
    /// are enforced.
    config: ElasticConfig,
    /// Wider runtime configuration; only `elastic_checkpoint_timeout` is
    /// read here.
    nexar_config: Arc<NexarConfig>,
    /// DER bytes of the cluster CA certificate (used for mTLS dials).
    ca_cert: Vec<u8>,
    /// DER bytes of this node's certificate.
    my_cert: Vec<u8>,
    /// DER bytes of this node's private key.
    my_key: Vec<u8>,
    /// Seed node address; required by `add_nodes`, otherwise unused here.
    seed_addr: Option<SocketAddr>,
    /// Listeners bound for workers created via `add_nodes`, kept alive so
    /// the existing mesh can dial them during the next resize.
    new_worker_listeners: Arc<StdMutex<Vec<(Rank, TransportListener)>>>,
}
/// Manual `Clone` because `AtomicU64` is not `Clone`: the epoch counter is
/// snapshotted into a fresh atomic while every `Arc`-wrapped field stays
/// shared with the original.
///
/// NOTE(review): since the counter is copied rather than shared, a clone's
/// `checkpoint_epoch` diverges from the original's once either runs a
/// checkpoint — confirm clones are not expected to agree on epochs.
impl Clone for ElasticManager {
    fn clone(&self) -> Self {
        Self {
            client: Arc::clone(&self.client),
            pending_joins: Arc::clone(&self.pending_joins),
            pending_leaves: Arc::clone(&self.pending_leaves),
            checkpoint_epoch: AtomicU64::new(self.checkpoint_epoch.load(Ordering::Relaxed)),
            event_tx: self.event_tx.clone(),
            config: self.config.clone(),
            nexar_config: Arc::clone(&self.nexar_config),
            ca_cert: self.ca_cert.clone(),
            my_cert: self.my_cert.clone(),
            my_key: self.my_key.clone(),
            seed_addr: self.seed_addr,
            new_worker_listeners: Arc::clone(&self.new_worker_listeners),
        }
    }
}
impl ElasticManager {
#[allow(clippy::too_many_arguments)]
pub fn new(
client: NexarClient,
config: ElasticConfig,
nexar_config: NexarConfig,
ca_cert: Vec<u8>,
my_cert: Vec<u8>,
my_key: Vec<u8>,
pending_joins: Arc<StdMutex<Vec<PendingJoin>>>,
seed_addr: Option<SocketAddr>,
) -> Self {
let (event_tx, _) = broadcast::channel(16);
Self {
client: Arc::new(Mutex::new(client)),
pending_joins,
pending_leaves: Arc::new(StdMutex::new(Vec::new())),
checkpoint_epoch: AtomicU64::new(0),
event_tx,
config,
nexar_config: Arc::new(nexar_config),
ca_cert,
my_cert,
my_key,
seed_addr,
new_worker_listeners: Arc::new(StdMutex::new(Vec::new())),
}
}
/// Returns a fresh receiver for membership-change events; each receiver
/// sees events sent after it subscribed.
pub fn subscribe(&self) -> broadcast::Receiver<ElasticEvent> {
    self.event_tx.subscribe()
}
/// Returns a shared handle to the underlying client (same instance the
/// manager resizes).
pub fn client(&self) -> Arc<Mutex<NexarClient>> {
    Arc::clone(&self.client)
}
/// Runs one elastic checkpoint round across the current mesh.
///
/// Drains the locally queued joins/leaves; if both are empty the call
/// degrades to a plain barrier and returns `Ok(None)`. Otherwise rank 0
/// coordinates the round while every other rank follows its decision, the
/// mesh is rebuilt to the new membership, and the resulting event is
/// broadcast and returned.
///
/// NOTE(review): the fast path assumes all ranks observe "no pending
/// changes" in the same round — if only rank 0 has pending joins, the
/// followers sit in `barrier()` while rank 0 waits for checkpoint
/// messages. Confirm callers guarantee agreement here.
pub async fn elastic_checkpoint(&self) -> Result<Option<ElasticEvent>> {
    // Take ownership of the queued membership changes, leaving the queues
    // empty for the next round. Poisoned locks are recovered, not fatal.
    let joining: Vec<PendingJoin> = {
        let mut pj = self.pending_joins.lock().unwrap_or_else(|p| p.into_inner());
        std::mem::take(&mut *pj)
    };
    let leaving: Vec<Rank> = {
        let mut pl = self
            .pending_leaves
            .lock()
            .unwrap_or_else(|p| p.into_inner());
        std::mem::take(&mut *pl)
    };
    // No membership change: just keep the ranks in lockstep.
    if joining.is_empty() && leaving.is_empty() {
        let client = self.client.lock().await;
        client.barrier().await?;
        return Ok(None);
    }
    // Snapshot rank/world size and claim the next epoch number while
    // holding the client lock, so the snapshot is internally consistent.
    let (epoch, rank, old_world_size, timeout) = {
        let client = self.client.lock().await;
        let epoch = self.checkpoint_epoch.fetch_add(1, Ordering::Relaxed);
        (
            epoch,
            client.rank(),
            client.world_size(),
            self.nexar_config.elastic_checkpoint_timeout,
        )
    };
    // Rank 0 decides the membership delta; everyone else asks rank 0.
    let (joining_info, leaving_ranks, new_world_size) = if rank == 0 {
        self.coordinate_as_rank0(epoch, old_world_size, timeout, &joining, &leaving)
            .await?
    } else {
        self.participate_as_follower(epoch, timeout).await?
    };
    // Rebuild the mesh (dial joiners, drop leavers) per the agreed delta.
    let event = self
        .apply_resize(
            old_world_size,
            new_world_size,
            rank,
            &joining_info,
            &leaving_ranks,
        )
        .await?;
    // Best effort: having no subscribers is not an error.
    let _ = self.event_tx.send(event.clone());
    Ok(Some(event))
}
/// Rank-0 side of the checkpoint: collect an `ElasticCheckpoint` message
/// from every other rank (each wait bounded by `timeout`), compute the
/// membership delta, then send the ack to all surviving ranks.
///
/// NOTE(review): the receive loop covers every rank in
/// `1..old_world_size`, including ranks listed in `leaving` — so leavers
/// are expected to still send the checkpoint message before departing,
/// and are only skipped for the ack. Confirm that contract.
async fn coordinate_as_rank0(
    &self,
    epoch: u64,
    old_world_size: u32,
    timeout: std::time::Duration,
    joining: &[PendingJoin],
    leaving: &[Rank],
) -> Result<(Vec<(Rank, String)>, Vec<Rank>, u32)> {
    let client = self.client.lock().await;
    // Phase 1: wait for every follower to reach this epoch's checkpoint.
    for src_rank in 1..old_world_size {
        let msg = tokio::time::timeout(timeout, client.recv_control(src_rank))
            .await
            .map_err(|_| NexarError::ElasticTimeout {
                epoch,
                timeout_ms: timeout.as_millis() as u64,
            })??;
        match msg {
            NexarMessage::ElasticCheckpoint { epoch: e } if e == epoch => {}
            other => {
                return Err(NexarError::Elastic(format!(
                    "expected ElasticCheckpoint(epoch={epoch}), got {other:?}"
                )));
            }
        }
    }
    // Phase 2: compute the delta. Joiner addresses travel with the ack so
    // followers can dial them.
    let joining_info: Vec<(Rank, String)> = joining
        .iter()
        .map(|pj| (pj.rank, pj.listen_addr.clone()))
        .collect();
    // NOTE(review): u32 arithmetic — assumes `leaving` never outnumbers
    // the current membership plus joiners, or this underflows. Confirm.
    let new_world_size = old_world_size + joining.len() as u32 - leaving.len() as u32;
    let ack = NexarMessage::ElasticCheckpointAck {
        epoch,
        joining: joining_info.clone(),
        leaving: leaving.to_vec(),
        new_world_size,
    };
    // Phase 3: fan the decision out to every rank that is staying.
    for dest_rank in 1..old_world_size {
        if leaving.contains(&dest_rank) {
            continue;
        }
        let peer = client.peer(dest_rank)?;
        peer.send_message(&ack, Priority::Critical).await?;
    }
    Ok((joining_info, leaving.to_vec(), new_world_size))
}
/// Follower side of the checkpoint: announce readiness for `epoch` to
/// rank 0, then wait (bounded by `timeout`) for the coordinator's ack
/// carrying the agreed membership delta.
async fn participate_as_follower(
    &self,
    epoch: u64,
    timeout: std::time::Duration,
) -> Result<(Vec<(Rank, String)>, Vec<Rank>, u32)> {
    let client = self.client.lock().await;
    // Tell the coordinator this rank has reached the checkpoint.
    client
        .peer(0)?
        .send_message(&NexarMessage::ElasticCheckpoint { epoch }, Priority::Critical)
        .await?;
    // Wait for the coordinator's decision, surfacing a timeout as a
    // dedicated error carrying the epoch.
    let reply = tokio::time::timeout(timeout, client.recv_control(0)).await;
    let ack_msg = reply.map_err(|_| NexarError::ElasticTimeout {
        epoch,
        timeout_ms: timeout.as_millis() as u64,
    })??;
    match ack_msg {
        NexarMessage::ElasticCheckpointAck {
            epoch: e,
            joining,
            leaving,
            new_world_size,
        } if e == epoch => Ok((joining, leaving, new_world_size)),
        other => Err(NexarError::Elastic(format!(
            "expected ElasticCheckpointAck(epoch={epoch}), got {other:?}"
        ))),
    }
}
/// Rebuilds the mesh after a membership decision: dials each joiner over
/// mTLS QUIC, then rebuilds the client without the leavers. Joins are
/// applied before leaves; each rebuild replaces the client in place while
/// the lock is held.
async fn apply_resize(
    &self,
    old_world_size: u32,
    new_world_size: u32,
    rank: Rank,
    joined: &[(Rank, String)],
    left: &[Rank],
) -> Result<ElasticEvent> {
    let mut client = self.client.lock().await;
    if !joined.is_empty() {
        // Build an mTLS client config from this node's stored identity.
        let ca_der = rustls::pki_types::CertificateDer::from(self.ca_cert.clone());
        let my_cert_der = rustls::pki_types::CertificateDer::from(self.my_cert.clone());
        let my_key_der = rustls::pki_types::PrivateKeyDer::try_from(self.my_key.clone())
            .map_err(|e| NexarError::Tls(format!("parse private key for mesh connect: {e}")))?;
        let client_config = make_client_config_mtls(my_cert_der, my_key_der, &ca_der)?;
        let mut new_peers: Vec<(Rank, Arc<PeerConnection>)> = Vec::new();
        for &(new_rank, ref addr_str) in joined {
            let addr: SocketAddr = addr_str.parse().map_err(|e| {
                NexarError::Elastic(format!(
                    "invalid listen address '{addr_str}' for rank {new_rank}: {e}"
                ))
            })?;
            // NOTE(review): the outbound endpoint binds to loopback only,
            // matching the joiner's IP family — this assumes joiners are
            // on the same host; confirm for multi-host deployments.
            let bind_addr: SocketAddr = if addr.is_ipv4() {
                "127.0.0.1:0"
            } else {
                "[::1]:0"
            }
            .parse()
            .expect("hardcoded socket addr");
            // A fresh endpoint is created per joiner; one endpoint could
            // dial multiple peers, but this keeps each dial independent.
            let mut endpoint = quinn::Endpoint::client(bind_addr).map_err(|e| {
                NexarError::transport_with_source("bind client for new peer", e)
            })?;
            endpoint.set_default_client_config(client_config.clone());
            // NOTE(review): SNI is hardcoded to "localhost" — presumably
            // the mesh certificates are issued for that name; confirm.
            let conn = endpoint
                .connect(addr, "localhost")
                .map_err(|e| {
                    NexarError::transport_with_source(
                        format!("connect to new rank {new_rank}"),
                        e,
                    )
                })?
                .await
                .map_err(|e| {
                    NexarError::transport_with_source(
                        format!("handshake with new rank {new_rank}"),
                        e,
                    )
                })?;
            let peer = Arc::new(PeerConnection::new(new_rank, conn));
            // Warm the peer's stream pool before it joins the mesh.
            peer.warm_stream_pool().await;
            new_peers.push((new_rank, peer));
        }
        // Swap in a client that includes the new peers.
        let rebuilt = client.rebuild_adding(new_peers).await?;
        *client = rebuilt;
    }
    if !left.is_empty() {
        // Swap in a client that excludes the departed ranks.
        let left_ranks: Vec<Rank> = left.to_vec();
        let rebuilt = client.rebuild_excluding(&left_ranks).await?;
        *client = rebuilt;
    }
    Ok(ElasticEvent {
        old_world_size,
        new_world_size,
        new_rank: rank,
        joined: joined.iter().map(|(r, _)| *r).collect(),
        left: left.to_vec(),
    })
}
/// Queues `rank` for removal at the next elastic checkpoint.
/// A poisoned queue lock is recovered rather than propagated.
pub fn remove_node(&self, rank: Rank) {
    let mut leaves = self
        .pending_leaves
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    leaves.push(rank);
}
/// Provisions `count` new workers through the seed node and queues them to
/// join at the next elastic checkpoint.
///
/// For each worker this connects a `WorkerNode` to the seed, binds a
/// loopback listener with the worker's mTLS identity, and records both the
/// pending join and the live listener (kept alive so the existing mesh can
/// dial it during the next resize). Returns the joins queued by this call.
///
/// # Errors
/// Fails when no seed address was configured, when the seed connection or
/// private-key parsing fails, or when a listener cannot be bound.
pub async fn add_nodes(&self, count: u32) -> Result<Vec<PendingJoin>> {
    let seed_addr = self.seed_addr.ok_or_else(|| {
        NexarError::Elastic("no seed address configured for add_nodes".into())
    })?;
    // Listeners bind to the loopback family matching the seed address;
    // the literal addresses below always parse.
    let loopback: SocketAddr = if seed_addr.is_ipv4() {
        "127.0.0.1:0"
    } else {
        "[::1]:0"
    }
    .parse()
    .expect("hardcoded socket addr");
    let mut queued = Vec::with_capacity(count as usize);
    for _ in 0..count {
        let worker = crate::cluster::WorkerNode::connect(seed_addr).await?;
        // Materialize the mTLS identity the seed handed to this worker.
        let ca_der = rustls::pki_types::CertificateDer::from(worker.ca_cert.clone());
        let cert_der = rustls::pki_types::CertificateDer::from(worker.node_cert.clone());
        let key_der = rustls::pki_types::PrivateKeyDer::try_from(worker.node_key.clone())
            .map_err(|e| NexarError::Tls(format!("parse new node private key: {e}")))?;
        let listener =
            TransportListener::bind_with_mtls(loopback, cert_der, key_der, &ca_der)?;
        let pending = PendingJoin {
            rank: worker.rank,
            listen_addr: listener.local_addr().to_string(),
        };
        queued.push(pending.clone());
        self.pending_joins
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push(pending);
        self.new_worker_listeners
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push((worker.rank, listener));
    }
    Ok(queued)
}
}