use std::collections::HashSet;
use std::fmt;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use rmp_serde as rmps;
use serde::Deserialize;
use serde::Serialize;
use serde::de::{self, Deserializer, IgnoredAny, MapAccess, SeqAccess, Visitor};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
use zeromq::{Socket, SocketRecv, SubSocket};
use dynamo_runtime::traits::DistributedRuntimeProvider;
use dynamo_runtime::transports::event_plane::EventPublisher;
use dynamo_runtime::{
component::{Component, Namespace},
transports::nats::{NatsQueue, Slug},
};
/// Builds the NATS stream name for a component-scoped KV subject.
///
/// The name is formed as `namespace.<ns>.component.<component>.<subject>`,
/// slugified, with underscores replaced by hyphens.
fn create_kv_stream_name(component: &Component, subject: &str) -> String {
    let qualified = format!(
        "namespace.{}.component.{}.{}",
        component.namespace().name(),
        component.name(),
        subject
    );
    let slug = Slug::slugify(&qualified).to_string();
    slug.replace("_", "-")
}
use crate::kv_router::{
KV_EVENT_SUBJECT, KV_METRICS_SUBJECT, WORKER_KV_INDEXER_BUFFER_SIZE,
indexer::{KvIndexerMetrics, LocalKvIndexer},
protocols::*,
worker_query::start_worker_kv_query_endpoint,
};
use dynamo_runtime::config::environment_names::nats as env_nats;
/// Delay applied after the first consecutive ZMQ receive error.
const INITIAL_BACKOFF_MS: u64 = 10;
/// Ceiling on the computed backoff delay.
/// NOTE(review): with the current values the exponent cap already limits the
/// delay to 10 * 2^8 = 2560 ms, so this ceiling is never the binding limit.
const MAX_BACKOFF_MS: u64 = 5000;
/// Number of consecutive ZMQ receive errors after which the listener gives up.
const MAX_CONSECUTIVE_ERRORS: u32 = 10;
/// Cap on the exponent in `calculate_backoff_ms`, bounding the doubling.
const MAX_BACKOFF_EXPONENT: u32 = 8;
/// Configuration for an external source of KV cache events.
pub enum KvEventSourceConfig {
    /// Subscribe to a ZMQ PUB socket at `endpoint`, filtered by `topic`.
    Zmq { endpoint: String, topic: String },
}
/// A running KV event source feeding the publisher's event channel.
enum KvEventSource {
    Zmq {
        // Handle of the spawned ZMQ listener task; aborted on shutdown.
        zmq_handle: tokio::task::JoinHandle<()>,
    },
}
impl KvEventSource {
    /// Spawns the configured event source on the runtime's secondary executor.
    ///
    /// Decoded `KvCacheEvent`s are pushed into `tx`; `next_event_id` is the
    /// shared counter used to stamp events with unique, increasing ids.
    fn start(
        component: Component,
        kv_block_size: u32,
        source_config: KvEventSourceConfig,
        cancellation_token: CancellationToken,
        tx: mpsc::UnboundedSender<KvCacheEvent>,
        next_event_id: Arc<AtomicU64>,
    ) -> Result<Self> {
        match source_config {
            KvEventSourceConfig::Zmq { endpoint, topic } => {
                let zmq_handle = component
                    .drt()
                    .runtime()
                    .secondary()
                    .spawn(start_zmq_listener(
                        endpoint,
                        topic,
                        tx,
                        cancellation_token.clone(),
                        kv_block_size,
                        next_event_id,
                    ));
                Ok(KvEventSource::Zmq { zmq_handle })
            }
        }
    }

    /// Hard-aborts the listener task. The task also observes the cancellation
    /// token, so this is a belt-and-braces stop.
    fn shutdown(&self) {
        match self {
            KvEventSource::Zmq { zmq_handle } => {
                zmq_handle.abort();
            }
        }
    }
}
/// Publishes one worker's KV cache events, fanning them out over the event
/// plane or NATS JetStream, and optionally into a `LocalKvIndexer`.
pub struct KvEventPublisher {
    // Size (in tokens) of a full KV block; partial blocks are not published.
    kv_block_size: u32,
    // Optional external event source (currently ZMQ); aborted on shutdown.
    source: Option<KvEventSource>,
    // Cancels the background processor tasks.
    cancellation_token: CancellationToken,
    // Producer side of the unbounded event channel drained by the processor.
    tx: mpsc::UnboundedSender<KvCacheEvent>,
    // Shared monotonically increasing event-id counter.
    next_event_id: Arc<AtomicU64>,
}
impl KvEventPublisher {
    /// Creates a publisher with no local indexer (and `dp_rank` 0).
    pub fn new(
        component: Component,
        kv_block_size: u32,
        source_config: Option<KvEventSourceConfig>,
    ) -> Result<Self> {
        Self::new_with_local_indexer(component, kv_block_size, source_config, false, 0)
    }

    /// Creates a publisher, optionally attaching a `LocalKvIndexer`.
    ///
    /// Pipeline: an optional event source (e.g. ZMQ) feeds `tx`; a background
    /// task drains `rx` and publishes each event either over the event plane
    /// (when `enable_local_indexer` is true) or over a NATS JetStream queue.
    /// With a local indexer, a worker KV query endpoint is also spawned
    /// (see `start_worker_kv_query_endpoint`).
    pub fn new_with_local_indexer(
        component: Component,
        kv_block_size: u32,
        source_config: Option<KvEventSourceConfig>,
        enable_local_indexer: bool,
        dp_rank: DpRank,
    ) -> Result<Self> {
        let cancellation_token = CancellationToken::new();
        let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
        let worker_id = component.drt().connection_id();
        let component_name = component.name();
        tracing::info!(
            "Initializing KvEventPublisher for worker {worker_id} in component {component_name}"
        );
        if enable_local_indexer {
            tracing::info!(
                "LocalKvIndexer enabled for worker {worker_id} in component {component_name}"
            );
        }
        // Event ids are allocated from a shared, monotonically increasing counter.
        let next_event_id = Arc::new(AtomicU64::new(0));
        let mut source = None;
        if let Some(config) = source_config {
            source = Some(KvEventSource::start(
                component.clone(),
                kv_block_size,
                config,
                cancellation_token.clone(),
                tx.clone(),
                next_event_id.clone(),
            )?);
        }
        let local_indexer = if enable_local_indexer {
            let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
            Some(Arc::new(LocalKvIndexer::new(
                cancellation_token.clone(),
                kv_block_size,
                metrics,
                WORKER_KV_INDEXER_BUFFER_SIZE,
            )))
        } else {
            None
        };
        // Spawn the worker KV query endpoint when a local indexer exists.
        // The JoinHandle is intentionally dropped: the tokio task runs detached.
        let _local_indexer_query_handle = local_indexer.as_ref().map(|local_indexer_ref| {
            let component = component.clone();
            let local_indexer = local_indexer_ref.clone();
            component
                .drt()
                .runtime()
                .secondary()
                .spawn(start_worker_kv_query_endpoint(
                    component,
                    worker_id,
                    dp_rank,
                    local_indexer,
                ))
        });
        let cancellation_token_clone = cancellation_token.clone();
        let local_indexer_clone = local_indexer.clone();
        if enable_local_indexer {
            // Event-plane path: publish via an EventPublisher scoped to this component.
            tracing::info!("Using event plane for KV event publishing (local_indexer mode)");
            let component_clone = component.clone();
            component.drt().runtime().secondary().spawn(async move {
                let event_publisher =
                    match EventPublisher::for_component(&component_clone, KV_EVENT_SUBJECT).await {
                        Ok(publisher) => publisher,
                        Err(e) => {
                            tracing::error!("Failed to create event publisher: {}", e);
                            return;
                        }
                    };
                start_event_processor(
                    event_publisher,
                    worker_id,
                    cancellation_token_clone,
                    rx,
                    local_indexer_clone,
                )
                .await
            });
        } else {
            // JetStream path: publish to a NATS queue derived from the component
            // identity; server address comes from the environment with a local default.
            let stream_name = create_kv_stream_name(&component, KV_EVENT_SUBJECT);
            let nats_server = std::env::var(env_nats::NATS_SERVER)
                .unwrap_or_else(|_| "nats://localhost:4222".to_string());
            let mut nats_queue = NatsQueue::new_without_consumer(
                stream_name,
                nats_server,
                std::time::Duration::from_secs(60), );
            component.drt().runtime().secondary().spawn(async move {
                if let Err(e) = nats_queue.connect().await {
                    tracing::error!("Failed to connect NatsQueue: {e}");
                    return;
                }
                start_event_processor_jetstream(
                    nats_queue,
                    worker_id,
                    cancellation_token_clone,
                    rx,
                    local_indexer_clone,
                )
                .await
            });
        }
        Ok(Self {
            kv_block_size,
            source,
            cancellation_token,
            tx,
            next_event_id,
        })
    }

    /// Queues an event for publishing. Fails only if the receiving processor
    /// task has shut down (channel closed).
    pub fn publish(&self, event: KvCacheEvent) -> Result<(), mpsc::error::SendError<KvCacheEvent>> {
        self.tx.send(event)
    }

    /// Returns a fresh unique event id (post-increment of the shared counter).
    pub fn next_event_id(&self) -> u64 {
        self.next_event_id.fetch_add(1, Ordering::SeqCst)
    }

    /// The KV block size this publisher was configured with.
    pub fn kv_block_size(&self) -> u32 {
        self.kv_block_size
    }

    /// Cancels background tasks and aborts the event source. Idempotent.
    pub fn shutdown(&mut self) {
        if !self.cancellation_token.is_cancelled() {
            self.cancellation_token.cancel();
        }
        if let Some(source) = self.source.take() {
            source.shutdown();
        }
    }
}
impl Drop for KvEventPublisher {
    // Ensure background tasks are cancelled even if the owner never calls
    // `shutdown()` explicitly.
    fn drop(&mut self) {
        self.shutdown();
    }
}
/// Transport abstraction for publishing `RouterEvent`s, letting the
/// event-processor loop be generic over the event plane and NATS JetStream
/// (and mockable in tests).
#[async_trait]
trait EventSink: Send + Sync {
    async fn publish_event(&self, event: &RouterEvent) -> Result<()>;
}
#[async_trait]
impl EventSink for EventPublisher {
    /// Publishes over the event plane; the subject was fixed when the
    /// `EventPublisher` was constructed.
    async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
        self.publish(event).await
    }
}
#[async_trait]
impl EventSink for NatsQueue {
    /// Publishes to the JetStream queue on the KV event subject.
    async fn publish_event(&self, event: &RouterEvent) -> Result<()> {
        // Fully-qualified call to the inherent `NatsQueue::publish_event`
        // (which also takes a subject) to avoid recursing into this trait method.
        NatsQueue::publish_event(self, KV_EVENT_SUBJECT, event).await
    }
}
/// Generic event-processing loop: drains `KvCacheEvent`s from `rx`, wraps each
/// in a `RouterEvent` for this worker, mirrors it into the optional local
/// indexer, then publishes through the sink. Runs until the cancellation token
/// fires or the channel closes.
async fn start_event_processor<P: EventSink + Send + Sync + 'static>(
    publisher: P,
    worker_id: u64,
    cancellation_token: CancellationToken,
    mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
    local_indexer: Option<Arc<LocalKvIndexer>>,
) {
    loop {
        tokio::select! {
            _ = cancellation_token.cancelled() => {
                tracing::info!("KV Event source received cancellation signal");
                break;
            }
            event = rx.recv() => {
                let Some(event) = event else {
                    tracing::debug!("Event processor channel closed.");
                    break;
                };
                tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
                let router_event = RouterEvent::new(worker_id, event);
                // Feed the local indexer first; failure there is logged but non-fatal.
                if let Some(indexer) = &local_indexer {
                    if let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await {
                        tracing::warn!(
                            "Failed to send event to local indexer for worker {}: {}",
                            worker_id,
                            e
                        );
                    }
                }
                // Publish failures are logged; the loop keeps processing.
                if let Err(e) = publisher.publish_event(&router_event).await {
                    tracing::error!("Failed to publish event: {}", e);
                }
            }
        }
    }
}
async fn start_event_processor_jetstream(
publisher: NatsQueue,
worker_id: u64,
cancellation_token: CancellationToken,
mut rx: mpsc::UnboundedReceiver<KvCacheEvent>,
local_indexer: Option<Arc<LocalKvIndexer>>,
) {
loop {
tokio::select! {
_ = cancellation_token.cancelled() => {
tracing::info!("KV Event source received cancellation signal");
break;
}
event = rx.recv() => {
let Some(event) = event else {
tracing::debug!("Event processor channel closed.");
break;
};
tracing::trace!("Event processor for worker_id {} processing event: {:?}", worker_id, event.data);
let router_event = RouterEvent::new(worker_id, event);
if let Some(indexer) = &local_indexer {
if let Err(e) = indexer.apply_event_with_buffer(router_event.clone()).await {
tracing::warn!(
"Failed to send event to local indexer for worker {}: {}",
worker_id,
e
);
}
}
if let Err(e) = publisher.publish_event(KV_EVENT_SUBJECT, &router_event).await {
tracing::error!("Failed to publish event to NATS JetStream: {}", e);
}
}
}
}
}
/// Computes the exponential-backoff delay (in milliseconds) for the given
/// consecutive error count: `INITIAL_BACKOFF_MS * 2^errors`, with the exponent
/// capped at `MAX_BACKOFF_EXPONENT` and the result clamped to `MAX_BACKOFF_MS`.
fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
    let exponent = consecutive_errors.min(MAX_BACKOFF_EXPONENT);
    // Shift is equivalent to multiplying by 2^exponent and cannot overflow
    // because the exponent is capped at MAX_BACKOFF_EXPONENT (8).
    let delay = INITIAL_BACKOFF_MS << exponent;
    delay.min(MAX_BACKOFF_MS)
}
/// Listens on a ZMQ SUB socket for engine KV event batches and forwards the
/// decoded events to `tx`.
///
/// Expected wire format per message: 3 frames
/// `[topic, sequence number (u64, big-endian), msgpack-encoded KvEventBatch]`.
/// Malformed messages are logged and skipped. Receive errors trigger
/// exponential backoff; the listener exits on cancellation, after
/// `MAX_CONSECUTIVE_ERRORS` consecutive errors, or when the channel receiver
/// is dropped.
pub async fn start_zmq_listener(
    zmq_endpoint: String,
    zmq_topic: String,
    tx: mpsc::UnboundedSender<KvCacheEvent>,
    cancellation_token: CancellationToken,
    kv_block_size: u32,
    next_event_id: Arc<AtomicU64>,
) {
    tracing::debug!(
        "KVEventPublisher connecting to ZMQ endpoint {} (topic '{}')",
        zmq_endpoint,
        zmq_topic
    );
    // Shared rate-limit counter for "partial block" warnings in create_stored_blocks.
    let warning_count = Arc::new(AtomicU32::new(0));
    let mut socket = SubSocket::new();
    if let Err(e) = socket.subscribe(&zmq_topic).await {
        tracing::error!("Failed to subscribe on ZMQ socket: {}", e);
        return;
    }
    if let Err(e) = socket.connect(&zmq_endpoint).await {
        tracing::error!("Failed to connect ZMQ SUB socket to {zmq_endpoint}: {e}");
        return;
    }
    let mut consecutive_errors = 0u32;
    #[allow(unused_assignments)]
    let mut exit_reason = "unknown";
    let mut messages_processed = 0u64;
    'main: loop {
        tokio::select! {
            // `biased` so cancellation is always checked before the socket.
            biased;
            _ = cancellation_token.cancelled() => {
                tracing::debug!("ZMQ listener received cancellation signal");
                exit_reason = "cancellation token cancelled";
                break 'main;
            }
            msg_result = socket.recv() => {
                let Ok(msg) = msg_result else {
                    let e = msg_result.unwrap_err();
                    consecutive_errors += 1;
                    if consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
                        tracing::error!(
                            error=%e,
                            consecutive_errors=%consecutive_errors,
                            "Too many consecutive ZMQ errors, terminating listener"
                        );
                        exit_reason = "too many consecutive errors";
                        break 'main;
                    }
                    let backoff_ms = calculate_backoff_ms(consecutive_errors);
                    tracing::warn!(
                        error=%e,
                        consecutive_errors=%consecutive_errors,
                        backoff_ms=%backoff_ms,
                        "Error reading from ZMQ socket, applying exponential backoff"
                    );
                    tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
                    continue;
                };
                // Any successful receive resets the error streak.
                consecutive_errors = 0;
                // Frames: [topic, seq (8 bytes, big-endian), msgpack payload].
                let mut frames: Vec<Vec<u8>> = msg.into_vec().into_iter().map(|frame| frame.to_vec()).collect();
                if frames.len() != 3 {
                    tracing::warn!("Received unexpected ZMQ frame count: expected 3, actual {}", frames.len());
                    continue;
                }
                let payload = frames.pop().unwrap();
                let seq_bytes = frames.pop().unwrap();
                if seq_bytes.len() != 8 {
                    tracing::warn!("Invalid sequence number byte length: expected 8, actual {}", seq_bytes.len());
                    continue;
                }
                // Length checked above, so this conversion cannot fail.
                let engine_seq = u64::from_be_bytes(seq_bytes.try_into().unwrap());
                let batch_result = rmps::from_slice::<KvEventBatch>(&payload);
                let Ok(batch) = batch_result else {
                    let e = batch_result.unwrap_err();
                    tracing::warn!("Failed to decode KVEventBatch msgpack: {e}");
                    continue;
                };
                tracing::trace!(
                    "ZMQ listener on {} received batch with {} events (engine_seq={}, dp_rank={})",
                    zmq_endpoint,
                    batch.events.len(),
                    engine_seq,
                    batch.data_parallel_rank.unwrap_or(0)
                );
                let dp_rank = batch.data_parallel_rank.unwrap_or(0) as u32;
                for raw_event in batch.events.into_iter() {
                    let event_id = next_event_id.fetch_add(1, Ordering::SeqCst);
                    let event = convert_event(raw_event, event_id, kv_block_size, dp_rank, &warning_count);
                    if tx.send(event).is_err() {
                        tracing::warn!("Failed to send message to channel - receiver dropped");
                        exit_reason = "channel receiver dropped";
                        break 'main;
                    }
                    messages_processed += 1;
                }
            }
        }
    }
    tracing::debug!(
        "ZMQ listener exiting, reason: {}, messages processed: {}",
        exit_reason,
        messages_processed
    );
}
/// Converts a raw engine event into the router's `KvCacheEvent` form,
/// stamping it with `event_id` and `dp_rank`.
fn convert_event(
    raw: RawKvEvent,
    event_id: u64,
    kv_block_size: u32,
    dp_rank: u32,
    warning_count: &Arc<AtomicU32>,
) -> KvCacheEvent {
    match raw {
        RawKvEvent::BlockStored {
            block_hashes,
            parent_block_hash,
            token_ids,
            block_size,
            lora_name,
            block_mm_infos,
            medium: _,
        } => {
            // Reject store events containing duplicate hashes (including a hash
            // equal to the parent's).
            // NOTE(review): the warning says "dropping", but a `Cleared` event
            // is emitted instead of a no-op — confirm consumers treat this as
            // intended (it may clear the worker's whole indexed state).
            {
                let mut seen = HashSet::with_capacity(block_hashes.len() + 1);
                if let Some(parent) = parent_block_hash {
                    seen.insert(parent.into_u64());
                }
                let has_duplicate = block_hashes.iter().any(|h| !seen.insert(h.into_u64()));
                if has_duplicate {
                    tracing::warn!(
                        event_id,
                        "Self-referencing block detected: duplicate hash in store event; dropping"
                    );
                    return KvCacheEvent {
                        event_id,
                        data: KvCacheEventData::Cleared,
                        dp_rank,
                    };
                }
            }
            // The raw event carries one block_size for all blocks in the batch.
            let num_block_tokens = vec![block_size as u64; block_hashes.len()];
            let block_hashes_u64: Vec<u64> = block_hashes
                .into_iter()
                .map(BlockHashValue::into_u64)
                .collect();
            KvCacheEvent {
                event_id,
                data: KvCacheEventData::Stored(KvCacheStoreData {
                    parent_hash: parent_block_hash
                        .map(BlockHashValue::into_u64)
                        .map(ExternalSequenceBlockHash::from),
                    blocks: create_stored_blocks(
                        kv_block_size,
                        &token_ids,
                        &num_block_tokens,
                        &block_hashes_u64,
                        lora_name.as_deref(),
                        warning_count,
                        block_mm_infos.as_deref(),
                    ),
                }),
                dp_rank,
            }
        }
        RawKvEvent::BlockRemoved { block_hashes, .. } => {
            let hashes = block_hashes
                .into_iter()
                .map(BlockHashValue::into_u64)
                .map(ExternalSequenceBlockHash::from)
                .collect();
            KvCacheEvent {
                event_id,
                data: KvCacheEventData::Removed(KvCacheRemoveData {
                    block_hashes: hashes,
                }),
                dp_rank,
            }
        }
        RawKvEvent::AllBlocksCleared => KvCacheEvent {
            event_id,
            data: KvCacheEventData::Cleared,
            dp_rank,
        },
    }
}
/// Builds one `KvCacheStoredBlockData` from a block's external hash, its
/// tokens, and optional LoRA / multimodal context.
///
/// The `tokens_hash` is recomputed locally via `compute_block_hash_for_seq`
/// so that routing uses a hash consistent across workers, independent of the
/// engine-provided `block_hash`.
pub fn create_stored_block_from_parts(
    kv_block_size: u32,
    block_hash: u64,
    token_ids: &[u32],
    lora_name: Option<&str>,
    mm_extra_info: Option<BlockExtraInfo>,
) -> KvCacheStoredBlockData {
    // Wrap the single block's MM info in the per-block vector shape expected
    // by the hash function.
    let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]);
    let tokens_hash = compute_block_hash_for_seq(
        token_ids,
        kv_block_size,
        block_mm_infos.as_deref(),
        lora_name,
    )[0];
    tracing::trace!(
        "Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}, mm_extra_info={:?}",
        block_hash,
        tokens_hash.0,
        token_ids,
        kv_block_size,
        mm_extra_info
    );
    KvCacheStoredBlockData {
        block_hash: ExternalSequenceBlockHash::from(block_hash),
        tokens_hash,
        mm_extra_info,
    }
}
/// Splits a flat token sequence into per-block stored-block records.
///
/// Processing stops at the first block whose token count differs from
/// `kv_block_size` (typically a partial trailing block): that block and
/// everything after it are not published. The mismatch warning is logged at
/// most 3 times per `warning_count` counter.
///
/// NOTE(review): assumes `token_ids` holds at least the sum of the consumed
/// `num_block_tokens` entries; the slice below panics otherwise — confirm the
/// event producer guarantees this.
pub fn create_stored_blocks(
    kv_block_size: u32,
    token_ids: &[u32],
    num_block_tokens: &[u64],
    block_hashes: &[u64],
    lora_name: Option<&str>,
    warning_count: &Arc<AtomicU32>,
    block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<KvCacheStoredBlockData> {
    let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
    let mut token_offset: usize = 0;
    for (block_idx, (num_tokens_it, block_hash_it)) in
        num_block_tokens.iter().zip(block_hashes.iter()).enumerate()
    {
        if *num_tokens_it != kv_block_size as u64 {
            if warning_count.fetch_add(1, Ordering::Relaxed) < 3 {
                tracing::warn!(
                    "Block not published. Block size must be {} tokens to be published. Block size is: {}",
                    kv_block_size,
                    *num_tokens_it
                );
            }
            break;
        }
        let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)];
        // Per-block MM info, when the engine provided any for this index.
        let mm_extra_info = block_mm_infos
            .and_then(|infos| infos.get(block_idx))
            .and_then(|opt| opt.clone());
        blocks.push(create_stored_block_from_parts(
            kv_block_size,
            *block_hash_it,
            tokens,
            lora_name,
            mm_extra_info,
        ));
        token_offset += *num_tokens_it as usize;
    }
    blocks
}
/// A batch of raw KV events as received from the engine over ZMQ.
///
/// Serialization is the derived map form; deserialization is manual and
/// expects the positional 3-tuple wire format (see the `Deserialize` impl).
#[derive(Debug, Serialize)]
struct KvEventBatch {
    // Producer-side timestamp (seconds, per the f64 wire type).
    ts: f64,
    events: Vec<RawKvEvent>,
    // NOTE(review): `alias` only affects *derived* Deserialize; with the
    // manual impl below this attribute currently has no effect.
    #[serde(alias = "dp_rank")]
    data_parallel_rank: Option<i32>,
}
impl<'de> Deserialize<'de> for KvEventBatch {
    /// Decodes the batch from its positional wire form: a 3-tuple of
    /// `(timestamp, events, data_parallel_rank)`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let arr: (f64, Vec<RawKvEvent>, Option<i32>) = Deserialize::deserialize(deserializer)?;
        Ok(KvEventBatch {
            ts: arr.0,
            events: arr.1,
            data_parallel_rank: arr.2,
        })
    }
}
/// A block hash as found on the wire; producers may encode it as either a
/// signed or an unsigned 64-bit integer (`untagged` tries both shapes).
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
#[serde(untagged)]
pub(crate) enum BlockHashValue {
    Signed(i64),
    Unsigned(u64),
}
impl BlockHashValue {
    /// Normalizes to `u64`. Signed values are reinterpreted bit-for-bit
    /// (two's complement `as` cast), not value-converted.
    pub(crate) fn into_u64(self) -> u64 {
        match self {
            BlockHashValue::Signed(v) => v as u64,
            BlockHashValue::Unsigned(v) => v,
        }
    }
}
/// Raw KV cache event as emitted by the inference engine.
///
/// Serialization uses an internally tagged map (the `type` field);
/// deserialization is manual (see `RawKvEventVisitor`) so that legacy
/// positional sequences and older field layouts are also accepted.
#[derive(Debug, Serialize, Clone)]
#[serde(tag = "type")] pub(crate) enum RawKvEvent {
    BlockStored {
        block_hashes: Vec<BlockHashValue>,
        parent_block_hash: Option<BlockHashValue>,
        token_ids: Vec<u32>,
        block_size: usize,
        #[serde(skip_serializing_if = "Option::is_none")]
        medium: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        lora_name: Option<String>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
    },
    BlockRemoved {
        block_hashes: Vec<BlockHashValue>,
        #[serde(skip_serializing_if = "Option::is_none")]
        medium: Option<String>,
    },
    AllBlocksCleared,
}
/// Extracts a 64-bit MM hash from an `extra_keys` entry.
///
/// Only 64-character ASCII-hex strings qualify (presumably a 256-bit digest in
/// hex — confirm with the producer); the hash is parsed from the first 16 hex
/// characters. Anything else yields `None`.
fn parse_mm_hash_from_extra_key(s: &str) -> Option<u64> {
    let is_hex64 = s.len() == 64 && s.bytes().all(|b| b.is_ascii_hexdigit());
    if !is_hex64 {
        return None;
    }
    u64::from_str_radix(&s[..16], 16).ok()
}
/// Converts legacy per-block `extra_keys` into `BlockExtraInfo` records.
///
/// Each block's key list is filtered down to parseable MM hashes
/// (`parse_mm_hash_from_extra_key`); blocks with no parseable keys map to
/// `None`. Returns `None` when the input is absent, empty, or produces no MM
/// info at all.
fn extra_keys_to_block_mm_infos(
    extra_keys: Option<Vec<Option<Vec<String>>>>,
) -> Option<Vec<Option<BlockExtraInfo>>> {
    let keys = extra_keys?;
    if keys.is_empty() {
        return None;
    }
    let mut infos: Vec<Option<BlockExtraInfo>> = Vec::with_capacity(keys.len());
    let mut any_some = false;
    for block_keys in keys {
        let mm_objects: Vec<BlockMmObjectInfo> = block_keys
            .into_iter()
            .flatten()
            .filter_map(|key| parse_mm_hash_from_extra_key(&key))
            .map(|mm_hash| BlockMmObjectInfo {
                mm_hash,
                // Offsets are not carried by the legacy format.
                offsets: Vec::new(),
            })
            .collect();
        if mm_objects.is_empty() {
            infos.push(None);
        } else {
            any_some = true;
            infos.push(Some(BlockExtraInfo { mm_objects }));
        }
    }
    any_some.then_some(infos)
}
impl<'de> Deserialize<'de> for RawKvEvent {
    /// Accepts both the tagged-map and the legacy positional-sequence
    /// encodings by delegating to `RawKvEventVisitor` via `deserialize_any`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_any(RawKvEventVisitor)
    }
}
/// Visitor decoding `RawKvEvent` from either a tagged map (current format) or
/// a positional sequence (legacy format).
struct RawKvEventVisitor;
impl<'de> Visitor<'de> for RawKvEventVisitor {
    type Value = RawKvEvent;
    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a kv event encoded as a tagged map or sequence")
    }
    /// Map form: `{"type": "...", ...fields}`. Unknown keys (e.g. the legacy
    /// `lora_id`) are skipped; legacy `extra_keys` are converted to
    /// `block_mm_infos` when the latter is absent.
    fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
    where
        A: MapAccess<'de>,
    {
        // Fields are double-wrapped: the outer Option tracks "key was present",
        // the inner Option is the field's own nullable value.
        let mut event_type: Option<String> = None;
        let mut block_hashes: Option<Vec<BlockHashValue>> = None;
        let mut parent_block_hash: Option<Option<BlockHashValue>> = None;
        let mut token_ids: Option<Vec<u32>> = None;
        let mut block_size: Option<usize> = None;
        let mut medium: Option<Option<String>> = None;
        let mut lora_name: Option<Option<String>> = None;
        let mut extra_keys: Option<Option<Vec<Option<Vec<String>>>>> = None;
        let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None;
        while let Some(key) = map.next_key::<String>()? {
            match key.as_str() {
                "type" => {
                    event_type = Some(map.next_value()?);
                }
                "block_hashes" => {
                    block_hashes = Some(map.next_value()?);
                }
                "parent_block_hash" => {
                    parent_block_hash = Some(map.next_value()?);
                }
                "token_ids" => {
                    token_ids = Some(map.next_value()?);
                }
                "block_size" => {
                    block_size = Some(map.next_value()?);
                }
                "medium" => {
                    medium = Some(map.next_value()?);
                }
                "lora_name" => {
                    lora_name = Some(map.next_value()?);
                }
                "extra_keys" => {
                    extra_keys = Some(map.next_value()?);
                }
                "block_mm_infos" => {
                    block_mm_infos = Some(map.next_value()?);
                }
                _ => {
                    // Tolerate unknown/legacy fields (e.g. lora_id) for
                    // forward/backward compatibility.
                    map.next_value::<IgnoredAny>()?;
                }
            }
        }
        match event_type.as_deref() {
            Some("BlockStored") => {
                let block_hashes =
                    block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?;
                let token_ids = token_ids.ok_or_else(|| de::Error::missing_field("token_ids"))?;
                let block_size =
                    block_size.ok_or_else(|| de::Error::missing_field("block_size"))?;
                // Prefer explicit block_mm_infos; otherwise derive from legacy extra_keys.
                let block_mm_infos = block_mm_infos
                    .unwrap_or(None)
                    .or_else(|| extra_keys_to_block_mm_infos(extra_keys.unwrap_or(None)));
                Ok(RawKvEvent::BlockStored {
                    block_hashes,
                    parent_block_hash: parent_block_hash.unwrap_or(None),
                    token_ids,
                    block_size,
                    medium: medium.unwrap_or(None),
                    lora_name: lora_name.unwrap_or(None),
                    block_mm_infos,
                })
            }
            Some("BlockRemoved") => {
                let block_hashes =
                    block_hashes.ok_or_else(|| de::Error::missing_field("block_hashes"))?;
                Ok(RawKvEvent::BlockRemoved {
                    block_hashes,
                    medium: medium.unwrap_or(None),
                })
            }
            Some("AllBlocksCleared") => Ok(RawKvEvent::AllBlocksCleared),
            Some(other) => Err(de::Error::unknown_variant(
                other,
                &["BlockStored", "BlockRemoved", "AllBlocksCleared"],
            )),
            None => Err(de::Error::missing_field("type")),
        }
    }
    /// Sequence (legacy positional) form: the first element is the variant
    /// tag, followed by the variant's fields in wire order. For `BlockStored`
    /// that order is: block_hashes, parent_block_hash, token_ids, block_size,
    /// lora_id (legacy, discarded), medium, lora_name, extra_keys,
    /// block_mm_infos. Trailing extra elements are drained and ignored.
    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
    where
        A: SeqAccess<'de>,
    {
        let tag: Option<String> = seq.next_element()?;
        let Some(tag) = tag else {
            return Err(de::Error::invalid_length(
                0,
                &"sequence must start with event tag",
            ));
        };
        match tag.as_str() {
            "BlockStored" => {
                let block_hashes: Vec<BlockHashValue> = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?;
                let parent_block_hash: Option<BlockHashValue> = seq.next_element()?.unwrap_or(None);
                let token_ids: Vec<u32> = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(3, &"missing token_ids"))?;
                let block_size: usize = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(4, &"missing block_size"))?;
                // Position 5 carried the numeric lora_id in old payloads; read and drop it.
                let _lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
                let medium: Option<String> = seq.next_element()?.unwrap_or(None);
                let lora_name: Option<String> = seq.next_element()?.unwrap_or(None);
                let extra_keys: Option<Vec<Option<Vec<String>>>> =
                    seq.next_element()?.unwrap_or(None);
                let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> =
                    seq.next_element()?.unwrap_or(None);
                // Drain any trailing elements so longer future layouts still parse.
                while seq.next_element::<IgnoredAny>()?.is_some() {}
                let block_mm_infos =
                    block_mm_infos.or_else(|| extra_keys_to_block_mm_infos(extra_keys));
                Ok(RawKvEvent::BlockStored {
                    block_hashes,
                    parent_block_hash,
                    token_ids,
                    block_size,
                    medium,
                    lora_name,
                    block_mm_infos,
                })
            }
            "BlockRemoved" => {
                let block_hashes: Vec<BlockHashValue> = seq
                    .next_element()?
                    .ok_or_else(|| de::Error::invalid_length(1, &"missing block_hashes"))?;
                let medium: Option<String> = seq.next_element()?.unwrap_or(None);
                while seq.next_element::<IgnoredAny>()?.is_some() {}
                Ok(RawKvEvent::BlockRemoved {
                    block_hashes,
                    medium,
                })
            }
            "AllBlocksCleared" => {
                while seq.next_element::<IgnoredAny>()?.is_some() {}
                Ok(RawKvEvent::AllBlocksCleared)
            }
            other => Err(de::Error::unknown_variant(
                other,
                &["BlockStored", "BlockRemoved", "AllBlocksCleared"],
            )),
        }
    }
}
/// Latest load snapshot for one (worker, dp_rank); `PartialEq` lets the
/// background publisher skip unchanged values.
#[derive(Debug, Clone, Default, PartialEq)]
struct WorkerMetrics {
    dp_rank: DpRank,
    active_decode_blocks: u64,
}
/// Publishes worker load metrics over NATS. `publish()` stores the latest
/// snapshot in a watch channel; a background task forwards changes.
pub struct WorkerMetricsPublisher {
    tx: tokio::sync::watch::Sender<WorkerMetrics>,
    // Retained so the channel stays open and the background task can clone it.
    rx: tokio::sync::watch::Receiver<WorkerMetrics>,
}
impl WorkerMetricsPublisher {
    /// Creates a publisher seeded with a default (zeroed) snapshot.
    pub fn new() -> Result<Self> {
        let (tx, rx) = tokio::sync::watch::channel(WorkerMetrics::default());
        Ok(WorkerMetricsPublisher { tx, rx })
    }

    /// Stores the latest metrics snapshot; `dp_rank` defaults to 0.
    ///
    /// # Errors
    /// Fails if the watch channel is closed (all receivers dropped).
    pub fn publish(&self, dp_rank: Option<DpRank>, active_decode_blocks: u64) -> Result<()> {
        let metrics = WorkerMetrics {
            dp_rank: dp_rank.unwrap_or(0),
            active_decode_blocks,
        };
        tracing::trace!(
            "Publish metrics: dp_rank={}, active_decode_blocks={}",
            metrics.dp_rank,
            metrics.active_decode_blocks
        );
        self.tx
            .send(metrics)
            .map_err(|_| anyhow::anyhow!("metrics channel closed"))
    }

    /// Starts background NATS publishing, scoped to the component's namespace
    /// and keyed by its connection id.
    pub async fn create_endpoint(&self, component: Component) -> Result<()> {
        let worker_id = component.drt().connection_id();
        self.start_nats_metrics_publishing(component.namespace().clone(), worker_id);
        Ok(())
    }

    /// Spawns the task that debounces snapshot changes and publishes them as
    /// `ActiveLoad` events on the KV metrics subject.
    ///
    /// NOTE(review): this uses `tokio::spawn` on the ambient runtime, unlike
    /// the other background tasks in this file which use the secondary
    /// runtime — confirm that is intentional.
    fn start_nats_metrics_publishing(&self, namespace: Namespace, worker_id: u64) {
        let nats_rx = self.rx.clone();
        tokio::spawn(async move {
            let event_publisher =
                match EventPublisher::for_namespace(&namespace, KV_METRICS_SUBJECT).await {
                    Ok(publisher) => publisher,
                    Err(e) => {
                        tracing::error!("Failed to create metrics publisher: {}", e);
                        return;
                    }
                };
            let mut rx = nats_rx;
            let mut last_metrics: Option<WorkerMetrics> = None;
            let mut pending_publish: Option<WorkerMetrics> = None;
            // Timer starts already elapsed so any pre-existing snapshot is
            // considered immediately on the first loop iteration.
            let mut publish_timer =
                Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(0)));
            publish_timer.as_mut().reset(tokio::time::Instant::now());
            loop {
                tokio::select! {
                    result = rx.changed() => {
                        if result.is_err() {
                            tracing::debug!(
                                "Metrics publisher sender dropped, stopping NATS background task"
                            );
                            break;
                        }
                        let metrics = rx.borrow_and_update().clone();
                        // Only schedule a publish when the snapshot actually changed.
                        let has_changed = last_metrics.as_ref() != Some(&metrics);
                        if has_changed {
                            pending_publish = Some(metrics.clone());
                            last_metrics = Some(metrics);
                            // 1 ms debounce: coalesces bursts of updates into one publish.
                            publish_timer.as_mut().reset(
                                tokio::time::Instant::now() + tokio::time::Duration::from_millis(1)
                            );
                        }
                    }
                    _ = &mut publish_timer => {
                        if let Some(metrics) = pending_publish.take() {
                            let active_load = ActiveLoad {
                                worker_id,
                                dp_rank: metrics.dp_rank,
                                active_decode_blocks: Some(metrics.active_decode_blocks),
                                active_prefill_tokens: None,
                            };
                            if let Err(e) = event_publisher.publish(&active_load).await {
                                tracing::warn!("Failed to publish metrics: {}", e);
                            }
                        }
                        // Park the timer far in the future; the next change re-arms it.
                        publish_timer.as_mut().reset(
                            tokio::time::Instant::now() + tokio::time::Duration::from_secs(3600)
                        );
                    }
                }
            }
        });
    }
}
// Unit tests for stored-block construction, event conversion, and the
// backward-compatible msgpack (de)serialization of `RawKvEvent`.
#[cfg(test)]
mod test_event_processing {
    use super::*;
    use crate::kv_router::protocols::compute_block_hash_for_seq;

    #[test]
    fn test_create_stored_block_from_parts() {
        let kv_block_size = 4;
        let token_ids = vec![10, 20, 30, 40];
        let blk_hash = 0xdead_beef;
        let stored =
            create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, None, None);
        assert_eq!(stored.block_hash.0, blk_hash);
        // tokens_hash must match the canonical hash of the same token block.
        let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None, None)[0];
        assert_eq!(stored.tokens_hash, expected_hash);
        assert!(stored.mm_extra_info.is_none());
    }

    #[test]
    fn test_create_stored_blocks_ok() {
        let kv_block_size = 4;
        let token_ids = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let num_block_tokens = vec![4_u64, 4_u64];
        let block_hashes = vec![111_u64, 222_u64];
        let blocks = create_stored_blocks(
            kv_block_size,
            &token_ids,
            &num_block_tokens,
            &block_hashes,
            None,
            &Arc::new(AtomicU32::new(0)),
            None,
        );
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].block_hash.0, 111);
        assert_eq!(blocks[1].block_hash.0, 222);
    }

    // A partial (3-token) second block must stop processing and bump the
    // rate-limited warning counter.
    #[test]
    fn test_create_stored_blocks_wrong_size_triggers_warning() {
        let kv_block_size = 4;
        let token_ids = vec![1, 2, 3, 4, 5, 6, 7];
        let num_block_tokens = vec![4_u64, 3_u64];
        let block_hashes = vec![111_u64, 222_u64];
        let warning_count = Arc::new(AtomicU32::new(0));
        let blocks = create_stored_blocks(
            kv_block_size,
            &token_ids,
            &num_block_tokens,
            &block_hashes,
            None,
            &warning_count,
            None,
        );
        assert!(blocks.len() == 1);
        assert!(warning_count.load(Ordering::Relaxed) == 1)
    }

    #[test]
    fn test_convert_event_block_stored() {
        let kv_block_size = 4;
        let raw_evt = RawKvEvent::BlockStored {
            block_hashes: vec![BlockHashValue::Unsigned(10), BlockHashValue::Unsigned(11)],
            parent_block_hash: Some(BlockHashValue::Unsigned(99)),
            token_ids: vec![1, 2, 3, 4, 5, 6, 7, 8],
            block_size: 4,
            medium: None,
            lora_name: None,
            block_mm_infos: None,
        };
        let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
        assert!(matches!(out.data, KvCacheEventData::Stored(_)));
    }

    // Same tokens with and without a LoRA name must hash differently.
    #[test]
    fn test_convert_event_with_lora_name() {
        let kv_block_size = 4;
        let token_ids = vec![1, 2, 3, 4];
        let base_evt = RawKvEvent::BlockStored {
            block_hashes: vec![BlockHashValue::Unsigned(10)],
            parent_block_hash: None,
            token_ids: token_ids.clone(),
            block_size: 4,
            medium: None,
            lora_name: None,
            block_mm_infos: None,
        };
        let lora_evt = RawKvEvent::BlockStored {
            block_hashes: vec![BlockHashValue::Unsigned(10)],
            parent_block_hash: None,
            token_ids: token_ids.clone(),
            block_size: 4,
            medium: None,
            lora_name: Some("my-lora".to_string()),
            block_mm_infos: None,
        };
        let wc = Arc::new(AtomicU32::new(0));
        let base_out = convert_event(base_evt, 1, kv_block_size, 0, &wc);
        let lora_out = convert_event(lora_evt, 2, kv_block_size, 0, &wc);
        let base_hash = match &base_out.data {
            KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
            _ => panic!("expected Stored"),
        };
        let lora_hash = match &lora_out.data {
            KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
            _ => panic!("expected Stored"),
        };
        assert_ne!(
            base_hash, lora_hash,
            "LoRA blocks must produce distinct tokens_hash"
        );
    }

    // Hashing is deterministic: two identical base-model events agree.
    #[test]
    fn test_convert_event_lora_name_none_is_base_model() {
        let kv_block_size = 4;
        let token_ids = vec![1, 2, 3, 4];
        let wc = Arc::new(AtomicU32::new(0));
        let evt1 = RawKvEvent::BlockStored {
            block_hashes: vec![BlockHashValue::Unsigned(10)],
            parent_block_hash: None,
            token_ids: token_ids.clone(),
            block_size: 4,
            medium: None,
            lora_name: None,
            block_mm_infos: None,
        };
        let evt2 = RawKvEvent::BlockStored {
            block_hashes: vec![BlockHashValue::Unsigned(10)],
            parent_block_hash: None,
            token_ids: token_ids.clone(),
            block_size: 4,
            medium: None,
            lora_name: None,
            block_mm_infos: None,
        };
        let out1 = convert_event(evt1, 1, kv_block_size, 0, &wc);
        let out2 = convert_event(evt2, 2, kv_block_size, 0, &wc);
        let hash1 = match &out1.data {
            KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
            _ => panic!("expected Stored"),
        };
        let hash2 = match &out2.data {
            KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
            _ => panic!("expected Stored"),
        };
        assert_eq!(
            hash1, hash2,
            "Two base-model events with same tokens should produce same hash"
        );
    }

    // Old map-format payloads carried a numeric lora_id; the visitor must
    // ignore it and leave lora_name unset.
    #[test]
    fn test_backward_compat_deserialize_map_with_lora_id_no_lora_name() {
        #[derive(serde::Serialize)]
        struct OldFormatEvent {
            #[serde(rename = "type")]
            event_type: &'static str,
            block_hashes: Vec<u64>,
            parent_block_hash: Option<u64>,
            token_ids: Vec<u32>,
            block_size: usize,
            lora_id: Option<u64>,
        }
        let payload = rmps::to_vec(&OldFormatEvent {
            event_type: "BlockStored",
            block_hashes: vec![42],
            parent_block_hash: None,
            token_ids: vec![1, 2, 3, 4],
            block_size: 4,
            lora_id: Some(5),
        })
        .unwrap();
        let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
        let RawKvEvent::BlockStored { lora_name, .. } = event else {
            panic!("expected BlockStored");
        };
        assert!(
            lora_name.is_none(),
            "old-format payloads with lora_id but no lora_name should deserialize with lora_name=None"
        );
    }

    // Same legacy lora_id handling, but for the positional-sequence encoding.
    #[test]
    fn test_backward_compat_deserialize_seq_with_lora_id_no_lora_name() {
        let payload = rmps::to_vec(&(
            "BlockStored",
            vec![42_u64],
            None::<u64>,
            vec![1_u32, 2, 3, 4],
            4_usize,
            Some(5_u64), ))
        .unwrap();
        let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
        let RawKvEvent::BlockStored { lora_name, .. } = event else {
            panic!("expected BlockStored");
        };
        assert!(
            lora_name.is_none(),
            "old seq-format payloads with lora_id should deserialize with lora_name=None"
        );
    }

    #[test]
    fn test_convert_event_block_removed() {
        let kv_block_size = 4;
        let raw_evt = RawKvEvent::BlockRemoved {
            block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
            medium: None,
        };
        let out = convert_event(raw_evt, 7, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
        assert!(matches!(out.data, KvCacheEventData::Removed(_)));
    }

    #[test]
    fn test_convert_event_all_blocks_cleared() {
        let kv_block_size = 4;
        let raw_evt = RawKvEvent::AllBlocksCleared;
        let out = convert_event(raw_evt, 1, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
        assert!(matches!(out.data, KvCacheEventData::Cleared));
    }

    #[test]
    fn test_parse_mm_hash_from_extra_key() {
        // Only 64-char hex strings qualify; the hash is the first 16 hex chars.
        assert_eq!(
            parse_mm_hash_from_extra_key(
                "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210"
            ),
            Some(0x0123_4567_89ab_cdef)
        );
        assert_eq!(parse_mm_hash_from_extra_key("123"), None);
        assert_eq!(parse_mm_hash_from_extra_key("not_a_hash"), None);
    }

    #[test]
    fn test_extra_keys_to_block_mm_infos() {
        let mm_hash =
            "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
        let infos = extra_keys_to_block_mm_infos(Some(vec![
            Some(vec![mm_hash.clone()]),
            None,
            // Unparseable keys are silently filtered out.
            Some(vec!["invalid".to_string(), mm_hash]),
        ]))
        .expect("expected parsed MM infos");
        assert_eq!(infos.len(), 3);
        assert_eq!(
            infos[0].as_ref().unwrap().mm_objects[0].mm_hash,
            0x0123_4567_89ab_cdef
        );
        assert!(infos[1].is_none());
        assert_eq!(
            infos[2].as_ref().unwrap().mm_objects[0].mm_hash,
            0x0123_4567_89ab_cdef
        );
    }

    // Positional field 8 may carry extra_keys, which must be converted to
    // block_mm_infos when no explicit block_mm_infos are present.
    #[test]
    fn test_seq_block_stored_field8_supports_extra_keys() {
        let mm_hash =
            "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
        let extra_keys_payload = rmps::to_vec(&(
            "BlockStored",
            vec![10_u64],
            None::<u64>,
            vec![1_u32, 2, 3, 4],
            4_usize,
            None::<u64>,
            None::<String>,
            None::<String>,
            vec![Some(vec![mm_hash])],
        ))
        .unwrap();
        let extra_keys_event: RawKvEvent = rmps::from_slice(&extra_keys_payload).unwrap();
        let RawKvEvent::BlockStored {
            lora_name,
            block_mm_infos,
            ..
        } = extra_keys_event
        else {
            panic!("expected BlockStored");
        };
        assert!(lora_name.is_none());
        assert_eq!(
            block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
            0x0123_4567_89ab_cdef
        );
    }

    #[test]
    fn test_map_block_stored_supports_extra_keys() {
        #[derive(serde::Serialize)]
        struct MapBlockStoredEvent {
            #[serde(rename = "type")]
            event_type: &'static str,
            block_hashes: Vec<u64>,
            parent_block_hash: Option<u64>,
            token_ids: Vec<u32>,
            block_size: usize,
            lora_id: Option<u64>,
            medium: Option<String>,
            lora_name: Option<String>,
            extra_keys: Option<Vec<Option<Vec<String>>>>,
        }
        let payload = rmps::to_vec(&MapBlockStoredEvent {
            event_type: "BlockStored",
            block_hashes: vec![10],
            parent_block_hash: None,
            token_ids: vec![1, 2, 3, 4],
            block_size: 4,
            lora_id: None,
            medium: Some("GPU".to_string()),
            lora_name: None,
            extra_keys: Some(vec![Some(vec![
                "0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string(),
            ])]),
        })
        .unwrap();
        let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
        let RawKvEvent::BlockStored { block_mm_infos, .. } = event else {
            panic!("expected BlockStored");
        };
        assert_eq!(
            block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
            0x0123_4567_89ab_cdef
        );
    }
}
#[cfg(test)]
mod tests_startup_helpers {
use super::*;
use crate::kv_router::KvIndexer;
use crate::kv_router::indexer::KvIndexerInterface;
use crate::kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use bytes::Bytes;
use std::sync::{Arc, Mutex};
use zeromq::{PubSocket, Socket, SocketSend, ZmqMessage};
/// Shared log of (subject, msgpack payload) pairs captured by [`MockComponent`].
type PublishedEvents = Arc<Mutex<Vec<(String, Vec<u8>)>>>;
/// Test double that records "published" router events in memory instead of
/// sending them over a real transport.
#[derive(Default)]
struct MockComponent {
    published: PublishedEvents,
}
impl MockComponent {
    /// Builds a mock component plus an external handle to its event log so
    /// tests can inspect what was published.
    fn new() -> (Self, PublishedEvents) {
        let published: PublishedEvents = Arc::default();
        let component = Self {
            published: Arc::clone(&published),
        };
        (component, published)
    }
}
#[async_trait::async_trait]
impl EventSink for MockComponent {
    /// Serializes the event with MessagePack and appends it to the shared
    /// in-memory log under the KV event subject, mimicking a real publish.
    async fn publish_event(&self, event: &RouterEvent) -> anyhow::Result<()> {
        let payload = rmp_serde::to_vec(event).unwrap();
        let mut log = self.published.lock().unwrap();
        log.push((KV_EVENT_SUBJECT.to_string(), payload));
        Ok(())
    }
}
#[tokio::test]
async fn test_start_event_processor() {
let (component, published) = MockComponent::new();
let event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
}),
dp_rank: 0,
};
let token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(event).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(component, 1, token, rx, None));
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
let published = published.lock().unwrap();
assert_eq!(published.len(), 1);
let (subject, _) = &published[0];
assert_eq!(subject, KV_EVENT_SUBJECT);
}
#[tokio::test]
async fn test_start_event_processor_with_local_indexer() {
    // A Stored event must be applied to the attached LocalKvIndexer in
    // addition to being published; we poll the indexer's worker list until
    // worker 1 becomes visible.
    let (component, published) = MockComponent::new();
    let token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    // block_size = 4, event buffer capacity = 100.
    let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
    let event = KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![
                KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(100),
                    tokens_hash: LocalBlockHash(200),
                    mm_extra_info: None,
                },
                KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(101),
                    tokens_hash: LocalBlockHash(201),
                    mm_extra_info: None,
                },
            ],
        }),
        dp_rank: 0,
    };
    let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
    tx.send(event).unwrap();
    // Closing the channel lets the processor drain the event and exit.
    drop(tx);
    let handle = tokio::spawn(start_event_processor(
        component,
        1,
        token.clone(),
        rx,
        Some(local_indexer.clone()),
    ));
    tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
        .await
        .unwrap()
        .unwrap();
    // Exactly one publish on the KV event subject.
    {
        let published_events = published.lock().unwrap();
        assert_eq!(published_events.len(), 1);
        let (subject, _) = &published_events[0];
        assert_eq!(subject, KV_EVENT_SUBJECT);
    }
    // Poll (up to ~200 ms total) until the indexer reports worker id 1.
    let get_workers_tx = local_indexer.get_workers_sender();
    let mut found = false;
    for _ in 0..20 {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        get_workers_tx
            .send(crate::kv_router::indexer::GetWorkersRequest { resp: resp_tx })
            .await
            .unwrap();
        let workers: Vec<u64> = resp_rx.await.unwrap();
        if workers.contains(&1) {
            found = true;
            break;
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }
    assert!(
        found,
        "Worker 1 was not found in the indexer after processing"
    );
    token.cancel();
}
#[tokio::test]
async fn test_event_processor_block_removed_with_local_indexer() {
    // Store a block, then remove it: afterwards the local indexer should
    // report no matches for the stored token hash, and both events should
    // have been published.
    let (component, published) = MockComponent::new();
    let token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
    let store_event = KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(100),
                tokens_hash: LocalBlockHash(200),
                mm_extra_info: None,
            }],
        }),
        dp_rank: 0,
    };
    let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
    tx.send(store_event).unwrap();
    let handle = tokio::spawn(start_event_processor(
        component,
        1,
        token.clone(),
        rx,
        Some(local_indexer.clone()),
    ));
    // Removal targets the block stored above (sent after the processor starts).
    let remove_event = KvCacheEvent {
        event_id: 2,
        data: KvCacheEventData::Removed(KvCacheRemoveData {
            block_hashes: vec![ExternalSequenceBlockHash(100)],
        }),
        dp_rank: 0,
    };
    tx.send(remove_event).unwrap();
    // Closing the channel lets the processor drain both events and exit.
    drop(tx);
    tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
        .await
        .unwrap()
        .unwrap();
    // Poll (up to ~200 ms) until the indexer reflects the removal.
    let mut no_blocks = false;
    for _ in 0..20 {
        let scores = local_indexer
            .find_matches(vec![LocalBlockHash(200)])
            .await
            .unwrap();
        if scores.scores.is_empty() {
            no_blocks = true;
            break;
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }
    assert!(no_blocks, "worker should have no blocks after removal");
    let published = published.lock().unwrap();
    assert_eq!(
        published.len(),
        2,
        "expected 2 published events, found {}",
        published.len()
    );
    token.cancel();
}
#[tokio::test]
async fn test_event_processor_all_blocks_cleared_with_local_indexer() {
    // Store a block, then send a Cleared event: the local indexer should end
    // up with no matches, and both events should have been published.
    let (component, published) = MockComponent::new();
    let token = CancellationToken::new();
    let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
    let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
    let store_event = KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![KvCacheStoredBlockData {
                block_hash: ExternalSequenceBlockHash(100),
                tokens_hash: LocalBlockHash(200),
                mm_extra_info: None,
            }],
        }),
        dp_rank: 0,
    };
    let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
    tx.send(store_event).unwrap();
    // Cleared wipes all blocks for the worker, not just specific hashes.
    let clear_event = KvCacheEvent {
        event_id: 2,
        data: KvCacheEventData::Cleared,
        dp_rank: 0,
    };
    tx.send(clear_event).unwrap();
    // Both events are queued before the processor starts; closing the channel
    // lets it drain them and exit.
    drop(tx);
    let handle = tokio::spawn(start_event_processor(
        component,
        1,
        token.clone(),
        rx,
        Some(local_indexer.clone()),
    ));
    tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
        .await
        .unwrap()
        .unwrap();
    // Poll (up to ~200 ms) until the indexer reflects the clear.
    let mut no_blocks = false;
    for _ in 0..20 {
        let scores = local_indexer
            .find_matches(vec![LocalBlockHash(200)])
            .await
            .unwrap();
        if scores.scores.is_empty() {
            no_blocks = true;
            break;
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }
    assert!(no_blocks, "worker should have no blocks after clearing");
    let published = published.lock().unwrap();
    assert_eq!(
        published.len(),
        2,
        "expected 2 published events, found {}",
        published.len()
    );
    token.cancel();
}
#[tokio::test]
async fn test_event_processor_local_indexer_failure_continues() {
let (component, published) = MockComponent::new();
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
token.cancel();
let event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1)],
}),
dp_rank: 0,
};
let new_token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<KvCacheEvent>();
tx.send(event).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(
component,
1,
new_token,
rx,
Some(local_indexer),
));
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
let published_events = published.lock().unwrap();
assert_eq!(published_events.len(), 1);
}
#[tokio::test]
async fn test_start_zmq_listener_pushes_to_channel() {
    // End-to-end ZMQ check: publish a msgpack-encoded KvEventBatch over a PUB
    // socket and assert the listener decodes it into a KvCacheEvent on the
    // channel.
    let (tx, mut rx) = mpsc::unbounded_channel::<KvCacheEvent>();
    let endpoint = "tcp://127.0.0.1:15555";
    // Empty topic subscribes to everything.
    let topic = "".to_string();
    let mut pub_socket = PubSocket::new();
    pub_socket.bind(endpoint).await.unwrap();
    let token = dynamo_runtime::CancellationToken::new();
    let next_event_id = Arc::new(AtomicU64::new(0));
    let listener_handle = tokio::spawn({
        let token = token.clone();
        start_zmq_listener(endpoint.to_string(), topic, tx, token, 4, next_event_id)
    });
    // Give the SUB side time to connect before publishing (PUB drops messages
    // sent before a subscriber is attached).
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    let seq: u64 = 77;
    let events = vec![RawKvEvent::BlockStored {
        block_hashes: vec![BlockHashValue::Unsigned(42)],
        parent_block_hash: None,
        token_ids: vec![0, 1, 2, 3],
        block_size: 4,
        medium: None,
        lora_name: None,
        block_mm_infos: None,
    }];
    let batch = KvEventBatch {
        ts: 0.0,
        events,
        data_parallel_rank: Some(1),
    };
    let payload = Bytes::from(rmps::to_vec(&batch).unwrap());
    // Three-frame wire layout: topic (empty), big-endian sequence number,
    // msgpack payload.
    let frames = vec![
        Bytes::from(""),
        Bytes::from(seq.to_be_bytes().to_vec()),
        payload.clone(),
    ];
    let msg = ZmqMessage::try_from(frames).expect("Failed to create ZmqMessage");
    pub_socket.send(msg).await.unwrap();
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    let event = rx.try_recv().expect("no message received");
    let KvCacheEventData::Stored(KvCacheStoreData {
        parent_hash,
        blocks,
    }) = event.data
    else {
        panic!("expected KvCacheStoreData");
    };
    assert!(parent_hash.is_none());
    assert_eq!(blocks.len(), 1);
    // The unsigned raw block hash carries through to the external hash.
    assert_eq!(blocks[0].block_hash.0, 42);
    token.cancel();
    let _ = listener_handle.await;
}
#[tokio::test]
async fn test_distributed_kvindexer_recovery_from_outage() {
    // Simulates a router missing a worker event during an "outage" and then
    // recovering by replaying buffered events from the worker's LocalKvIndexer
    // via get_events_in_id_range.
    let worker_1_id = 1u64;
    let block_size = 4u32;
    let token = CancellationToken::new();
    // Worker side: event processor plus a local indexer that buffers every
    // event it applies.
    let (worker_component, worker_published) = MockComponent::new();
    let local_indexer_1 = Arc::new(LocalKvIndexer::new(
        token.clone(),
        block_size,
        Arc::new(KvIndexerMetrics::new_unregistered()),
        100,
    ));
    let (worker_tx, worker_rx) = mpsc::unbounded_channel::<KvCacheEvent>();
    tokio::spawn(start_event_processor(
        worker_component,
        worker_1_id,
        token.clone(),
        worker_rx,
        Some(local_indexer_1.clone()),
    ));
    // Router side: a global indexer that this test feeds manually from the
    // mock's published log (standing in for the real transport).
    let router_indexer = Arc::new(KvIndexer::new(
        token.clone(),
        block_size,
        Arc::new(KvIndexerMetrics::new_unregistered()),
    ));
    let event_1 = KvCacheEvent {
        event_id: 1,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![
                KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(100),
                    tokens_hash: LocalBlockHash(200),
                    mm_extra_info: None,
                },
                KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(101),
                    tokens_hash: LocalBlockHash(201),
                    mm_extra_info: None,
                },
            ],
        }),
        dp_rank: 0,
    };
    worker_tx.send(event_1.clone()).unwrap();
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    // Phase 1 — normal operation: relay the published event to the router.
    let (subject, bytes) = {
        let published = worker_published.lock().unwrap();
        assert_eq!(published.len(), 1, "Worker should have published 1 event");
        (published[0].0.clone(), published[0].1.clone())
    };
    assert_eq!(subject, KV_EVENT_SUBJECT);
    let router_event: RouterEvent = rmp_serde::from_slice(&bytes).unwrap();
    router_indexer
        .event_sender()
        .send(router_event)
        .await
        .unwrap();
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    // Poll (up to ~200 ms) until the router indexer reports worker 1.
    let get_workers_tx = router_indexer.get_workers_sender();
    let mut router_has_worker = false;
    for _ in 0..20 {
        let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
        get_workers_tx
            .send(crate::kv_router::indexer::GetWorkersRequest { resp: resp_tx })
            .await
            .unwrap();
        let workers: Vec<u64> = resp_rx.await.unwrap();
        if workers.contains(&worker_1_id) {
            router_has_worker = true;
            break;
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
    }
    assert!(
        router_has_worker,
        "Router should see worker 1 after normal operation"
    );
    let buffered = local_indexer_1.get_all_events_in_buffer();
    assert_eq!(buffered.len(), 1, "Local indexer should buffer 1 event");
    // Phase 2 — outage: event_2 is processed by the worker but deliberately
    // NOT relayed to the router indexer.
    let event_2 = KvCacheEvent {
        event_id: 2,
        data: KvCacheEventData::Stored(KvCacheStoreData {
            parent_hash: None,
            blocks: vec![
                KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(100),
                    tokens_hash: LocalBlockHash(200),
                    mm_extra_info: None,
                },
                KvCacheStoredBlockData {
                    block_hash: ExternalSequenceBlockHash(102),
                    tokens_hash: LocalBlockHash(202),
                    mm_extra_info: None,
                },
            ],
        }),
        dp_rank: 0,
    };
    worker_tx.send(event_2.clone()).unwrap();
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    {
        let published = worker_published.lock().unwrap();
        assert_eq!(
            published.len(),
            2,
            "Worker should have published 2 events total"
        );
    }
    let buffered = local_indexer_1.get_all_events_in_buffer();
    assert_eq!(
        buffered.len(),
        2,
        "Local indexer should have both events during outage"
    );
    // Router is stale: of event_2's blocks only hash 200 (from event_1) is
    // known to it.
    let block_hashes_2 = vec![LocalBlockHash(200), LocalBlockHash(202)];
    let overlap = router_indexer
        .find_matches(block_hashes_2.clone())
        .await
        .unwrap();
    let router_overlap = overlap
        .scores
        .get(&crate::kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id))
        .copied()
        .unwrap_or(0);
    assert_eq!(
        router_overlap, 1,
        "Router should only see 1 shared block (not the new block from event_2)"
    );
    // Phase 3 — recovery: fetch everything after the last event id the router
    // saw (id 1), expecting exactly event_2 back.
    let last_known_id = 1u64;
    let response = local_indexer_1
        .get_events_in_id_range(Some(last_known_id + 1), None)
        .await;
    let missed_events = match response {
        crate::kv_router::indexer::WorkerKvQueryResponse::Events(e) => e,
        crate::kv_router::indexer::WorkerKvQueryResponse::TreeDump(e) => e,
        crate::kv_router::indexer::WorkerKvQueryResponse::Error(message) => {
            panic!("Unexpected error response: {message}")
        }
        other => panic!("Unexpected response: {:?}", other),
    };
    assert_eq!(
        missed_events.len(),
        1,
        "Should get 1 missed event (event_2 with id=2)"
    );
    // Replay the missed events into the router indexer.
    for router_event in missed_events {
        router_indexer
            .event_sender()
            .send(router_event)
            .await
            .unwrap();
    }
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    let overlap = router_indexer.find_matches(block_hashes_2).await.unwrap();
    let router_overlap_after = overlap
        .scores
        .get(&crate::kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id))
        .copied()
        .unwrap_or(0);
    assert_eq!(
        router_overlap_after, 2,
        "Router should now see both blocks after recovery"
    );
    token.cancel();
}
}
#[cfg(test)]
mod test_exponential_backoff {
    use super::*;
    #[test]
    fn test_backoff_calculation_progression() {
        // Delay doubles per consecutive error: INITIAL_BACKOFF_MS * 2^n.
        for (errors, expected_ms) in [
            (0, 10),
            (1, 20),
            (2, 40),
            (3, 80),
            (4, 160),
            (5, 320),
            (6, 640),
            (7, 1280),
            (8, 2560),
        ] {
            assert_eq!(calculate_backoff_ms(errors), expected_ms);
        }
    }
    #[test]
    fn test_backoff_caps_at_max_exponent() {
        // Beyond MAX_BACKOFF_EXPONENT the delay stops growing.
        assert_eq!(calculate_backoff_ms(8), 2560);
        assert_eq!(calculate_backoff_ms(9), 2560);
        assert_eq!(calculate_backoff_ms(100), 2560);
    }
    #[test]
    fn test_backoff_never_exceeds_max() {
        for errors in 0..20 {
            assert!(calculate_backoff_ms(errors) <= MAX_BACKOFF_MS);
        }
    }
    #[test]
    #[allow(clippy::assertions_on_constants)]
    fn test_backoff_constants_are_sane() {
        // Sanity-check that the tuning constants are mutually consistent.
        assert!(INITIAL_BACKOFF_MS > 0);
        assert!(MAX_BACKOFF_MS > INITIAL_BACKOFF_MS);
        assert!(MAX_BACKOFF_EXPONENT <= 10);
        assert!(MAX_CONSECUTIVE_ERRORS > 0);
        let worst_case = INITIAL_BACKOFF_MS * 2_u64.pow(MAX_BACKOFF_EXPONENT);
        assert!(worst_case <= MAX_BACKOFF_MS);
    }
}
#[cfg(all(test, feature = "integration"))]
mod test_integration_publisher {
    use super::*;
    use crate::kv_router::protocols::ActiveLoad;
    use dynamo_runtime::distributed_test_utils::create_test_drt_async;
    use dynamo_runtime::transports::event_plane::EventSubscriber;
    /// Verifies metrics publishing over NATS: a burst of changing loads yields
    /// only the latest value, and re-publishing an unchanged load yields no
    /// messages. Requires a live NATS instance, hence `#[ignore]`.
    #[tokio::test]
    #[ignore]
    async fn test_metrics_publishing_behavior() -> Result<()> {
        let drt = create_test_drt_async().await;
        let namespace = drt.namespace("ns2001".to_string())?;
        let mut subscriber = EventSubscriber::for_namespace(&namespace, KV_METRICS_SUBJECT)
            .await
            .unwrap()
            .typed::<ActiveLoad>();
        let publisher = WorkerMetricsPublisher::new().unwrap();
        let worker_id = 1234;
        publisher.start_nats_metrics_publishing(namespace.clone(), worker_id);
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        // Rapid burst of changing loads: 0, 100, ..., 900.
        for i in 0..10 {
            publisher.publish(None, (i * 100) as u64).unwrap();
            tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        let result =
            tokio::time::timeout(tokio::time::Duration::from_millis(500), subscriber.next())
                .await
                .unwrap();
        // Only the final value of the burst (900) is expected on the wire.
        let (_envelope, event) = result.unwrap().unwrap();
        assert_eq!(event.worker_id, worker_id);
        assert_eq!(event.active_decode_blocks, Some(900));
        assert_eq!(event.active_prefill_tokens, None);
        // No further messages should follow the single coalesced update.
        let no_msg =
            tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await;
        assert!(no_msg.is_err(), "Expected no more messages, but found one");
        // Re-publishing the identical load must not produce new messages.
        for _ in 0..10 {
            publisher.publish(None, 900).unwrap();
            tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
        }
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        let no_msg =
            tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await;
        assert!(
            no_msg.is_err(),
            "Expected no messages when load metrics don't change"
        );
        drt.shutdown();
        Ok(())
    }
}