use super::*;
use utils::*;
use zmq::*;
use derive_builder::Builder;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::time::Duration;
use tokio::sync::Notify;
use tokio::sync::OnceCell;
use tokio::sync::oneshot;
use tokio::time::sleep;
/// Sizing for a single KVBM cache tier, given either as a capacity in GB or
/// as an explicit block count.
///
/// `compute_num_blocks` turns this into a block count: a non-zero
/// `num_blocks_overriden` is used as-is, otherwise `cache_size_in_gb` is
/// converted to bytes and divided by the per-block size.
#[derive(Builder, Clone, Debug, Default)]
pub struct KvbmLeaderNumBlocksConfig {
/// Tier capacity in gigabytes (`compute_num_blocks` treats 1 GB as
/// 1_000_000_000 bytes). Builder default: `0.0`.
#[builder(default = "0.0")]
pub cache_size_in_gb: f64,
/// Explicit block count; any value greater than 0 takes precedence over
/// `cache_size_in_gb`. Builder default: `0`.
#[builder(default = "0")]
pub num_blocks_overriden: usize,
}
/// Resolves how many blocks a cache tier should hold.
///
/// * If `num_blocks_config.num_blocks_overriden` is greater than 0 it is
///   returned unchanged.
/// * Otherwise the count is `cache_size_in_gb * 1e9 / bytes_per_block`,
///   truncated toward zero.
///
/// Returns 0 when `bytes_per_block` is 0 and no override is set. Without that
/// guard the division produces `+inf` for any positive cache size, and the
/// saturating float-to-integer cast would turn it into `usize::MAX` blocks.
/// A negative or NaN cache size also resolves to 0 through the same cast.
fn compute_num_blocks(
    num_blocks_config: &KvbmLeaderNumBlocksConfig,
    bytes_per_block: usize,
) -> usize {
    if num_blocks_config.num_blocks_overriden > 0 {
        return num_blocks_config.num_blocks_overriden;
    }

    // A zero block size carries no sizing information; report an empty tier
    // instead of dividing by zero.
    if bytes_per_block == 0 {
        return 0;
    }

    let capacity_bytes = num_blocks_config.cache_size_in_gb * 1_000_000_000.0;
    (capacity_bytes / bytes_per_block as f64) as usize
}
/// Configuration consumed by `KvbmLeader`.
///
/// Built through the `derive_builder`-generated `KvbmLeaderConfigBuilder`;
/// every field has a builder default, listed on the field.
#[derive(Builder, Clone, Debug)]
pub struct KvbmLeaderConfig {
/// World size handed to the ZMQ handshake. Builder default: `1`.
/// NOTE(review): presumably the number of workers the leader waits for —
/// confirm against `ZmqActiveMessageLeader::new_with_handshake`.
#[builder(default = "1")]
world_size: usize,
/// Timeout in seconds, used both for the ZMQ handshake and as the upper
/// bound of `KvbmLeader::wait_worker_sync_ready`. Builder default: `120`.
#[builder(default = "120")]
leader_init_timeout_secs: u64,
/// Sizing used to compute the host block count. Builder default: the
/// all-zero `KvbmLeaderNumBlocksConfig`.
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
host_blocks_config: KvbmLeaderNumBlocksConfig,
/// Sizing used to compute the disk block count. Builder default: the
/// all-zero `KvbmLeaderNumBlocksConfig`.
#[builder(default = "KvbmLeaderNumBlocksConfig::default()")]
disk_blocks_config: KvbmLeaderNumBlocksConfig,
/// First URL passed to `new_leader_sockets`. Must differ from
/// `leader_ack_url` (enforced by `sanity_check`).
/// Builder default: `tcp://127.0.0.1:56001`.
#[builder(default = "String::from(\"tcp://127.0.0.1:56001\")")]
leader_pub_url: String,
/// Second URL passed to `new_leader_sockets`.
/// Builder default: `tcp://127.0.0.1:56002`.
#[builder(default = "String::from(\"tcp://127.0.0.1:56002\")")]
leader_ack_url: String,
}
impl KvbmLeaderConfig {
    /// Returns a builder with every field unset, so each one falls back to
    /// its declared default when `build()` is called.
    pub fn builder() -> KvbmLeaderConfigBuilder {
        KvbmLeaderConfigBuilder::default()
    }

    /// Validates the configuration before it is used to start a leader.
    ///
    /// # Errors
    ///
    /// * `leader_pub_url` and `leader_ack_url` are identical.
    /// * Neither the host (CPU) tier nor the disk tier is configured, i.e.
    ///   both have `num_blocks_overriden == 0` and `cache_size_in_gb <= 0.0`.
    ///
    /// Both conditions are reported through the returned `Result`, so a
    /// misconfiguration is handled by the caller instead of aborting the
    /// process with a panic.
    pub fn sanity_check(&self) -> anyhow::Result<()> {
        if self.leader_pub_url == self.leader_ack_url {
            anyhow::bail!(
                "leader_pub_url and leader_ack_url must differ (same endpoint would fail to bind)."
            );
        }

        let cpu = &self.host_blocks_config;
        let disk = &self.disk_blocks_config;
        let cpu_configured = cpu.num_blocks_overriden > 0 || cpu.cache_size_in_gb > 0.0;
        let disk_configured = disk.num_blocks_overriden > 0 || disk.cache_size_in_gb > 0.0;

        if !cpu_configured && !disk_configured {
            // A trailing `\` in the literal swallows the following newline and
            // leading whitespace, so the indentation below is not part of the
            // message.
            anyhow::bail!(
                "KVBM Configuration Error: At least one cache tier must be configured.\n\
                \n\
                Configure CPU cache (G2) for CPU memory offloading:\n\
                • DYN_KVBM_CPU_CACHE_GB=<size_in_gb> (e.g., DYN_KVBM_CPU_CACHE_GB=4)\n\
                • DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS=<num_blocks> (e.g., DYN_KVBM_CPU_CACHE_OVERRIDE_NUM_BLOCKS=1000)\n\
                \n\
                OR configure disk cache (G3) for direct GPU->Disk offloading:\n\
                • DYN_KVBM_DISK_CACHE_GB=<size_in_gb> (e.g., DYN_KVBM_DISK_CACHE_GB=8)\n\
                • DYN_KVBM_DISK_CACHE_OVERRIDE_NUM_BLOCKS=<num_blocks>\n\
                \n\
                Note: If only disk cache is configured, KVBM will offload directly from GPU (G1) to Disk (G3), bypassing CPU memory (G2)."
            );
        }

        Ok(())
    }
}
/// State shared between `KvbmLeader` and its background handshake task.
///
/// `Default` initialises every counter to 0 and the ready flag to `false`.
#[derive(Debug, Default)]
pub struct KvbmLeaderState {
/// Smallest `num_device_blocks` reported by any worker during the handshake.
pub num_device_blocks: Arc<AtomicUsize>,
/// Host block count computed from the host tier config during the handshake.
pub num_host_blocks: Arc<AtomicUsize>,
/// Disk block count computed from the disk tier config during the handshake.
pub num_disk_blocks: Arc<AtomicUsize>,
/// Set to `true` once the ZMQ handshake has completed successfully.
pub workers_allocation_ready: Arc<AtomicBool>,
/// Notified (all current waiters) right after `workers_allocation_ready`
/// is set.
pub workers_ready_notify: Arc<Notify>,
}
/// Leader side of the KVBM leader/worker protocol: performs the ZMQ handshake
/// with the workers in a background task and broadcasts block transfer
/// requests once that handshake has completed.
pub struct KvbmLeader {
// Block counts and readiness flag, shared with the handshake task.
state: Arc<KvbmLeaderState>,
// Empty until the handshake task succeeds and stores the ZMQ leader in it.
zmq_leader: Arc<OnceCell<ZmqActiveMessageLeader>>,
// Configuration this leader was created with.
config: KvbmLeaderConfig,
}
impl KvbmLeader {
pub async fn new(config: KvbmLeaderConfig) -> anyhow::Result<Self> {
let leader_sockets = new_leader_sockets(&config.leader_pub_url, &config.leader_ack_url)?;
let leader = Self {
state: Arc::new(KvbmLeaderState::default()),
zmq_leader: Arc::new(tokio::sync::OnceCell::new()),
config,
};
let cancel_token = tokio_util::sync::CancellationToken::new();
leader.spawn_zmq_task(leader_sockets, cancel_token);
Ok(leader)
}
fn spawn_zmq_task(
&self,
leader_sockets: LeaderSockets,
cancel: tokio_util::sync::CancellationToken,
) {
let cell = self.zmq_leader.clone();
let state = self.state.clone();
let world_size = self.config.world_size;
let timeout = self.config.leader_init_timeout_secs;
let host_cfg = self.config.host_blocks_config.clone();
let disk_cfg = self.config.disk_blocks_config.clone();
let num_device_blocks_cell = state.num_device_blocks.clone();
let num_host_blocks_cell = state.num_host_blocks.clone();
let num_disk_blocks_cell = state.num_disk_blocks.clone();
tokio::spawn(async move {
let res = ZmqActiveMessageLeader::new_with_handshake(
leader_sockets,
world_size,
std::time::Duration::from_secs(timeout),
cancel.clone(),
move |workers: &[WorkerMetadata]| -> LeaderMetadata {
if let Some(min_dev) = workers.iter().map(|w| w.num_device_blocks).min() {
num_device_blocks_cell.store(min_dev, Ordering::Release);
}
let bytes_per_block: usize = workers.iter().map(|w| w.bytes_per_block).sum();
let num_host_blocks = compute_num_blocks(&host_cfg, bytes_per_block);
let num_disk_blocks = compute_num_blocks(&disk_cfg, bytes_per_block);
num_host_blocks_cell.store(num_host_blocks, Ordering::Release);
num_disk_blocks_cell.store(num_disk_blocks, Ordering::Release);
LeaderMetadata {
num_host_blocks,
num_disk_blocks,
}
},
)
.await;
match res {
Ok(zmq) => {
let _ = cell.set(zmq);
state
.workers_allocation_ready
.store(true, Ordering::Release);
state.workers_ready_notify.notify_waiters();
tracing::info!("ZMQ handshake complete; workers allocation ready");
}
Err(e) => {
tracing::error!("ZMQ init/handshake failed: {e:?}");
}
}
});
}
pub async fn transfer_blocks_request(
&self,
request: BlockTransferRequest,
) -> anyhow::Result<oneshot::Receiver<()>> {
let zmq = self
.zmq_leader
.get()
.ok_or_else(|| anyhow::anyhow!("ZMQ leader not ready"))?;
let data = vec![serde_json::to_vec(&request)?];
zmq.broadcast(ZMQ_TRANSFER_BLOCKS_MESSAGE, data).await
}
pub fn num_device_blocks(&self) -> usize {
self.state.num_device_blocks.load(Ordering::Acquire)
}
pub fn num_host_blocks(&self) -> usize {
self.state.num_host_blocks.load(Ordering::Acquire)
}
pub fn num_disk_blocks(&self) -> usize {
self.state.num_disk_blocks.load(Ordering::Acquire)
}
pub async fn wait_worker_sync_ready(&self) -> bool {
if self.state.workers_allocation_ready.load(Ordering::Acquire) {
return true;
}
let notified = self.state.workers_ready_notify.notified();
tokio::select! {
_ = notified => true,
_ = sleep(Duration::from_secs(self.config.leader_init_timeout_secs)) => false,
}
}
}