use super::*;
#[allow(unused_imports)]
use bytes::Bytes;
#[allow(unused_imports)]
use dynamo_kv_router::RouterEventSink;
#[allow(unused_imports)]
use rmp_serde as rmps;
#[allow(unused_imports)]
use std::future::Future;
#[allow(unused_imports)]
use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
use std::time::Duration;
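// Unit tests for the raw KV event helpers: building stored-block data from event
// parts, converting RawKvEvent variants into KvCacheEventData, LoRA-aware token
// hashing, multimodal extra-key parsing, and msgpack backward compatibility with
// older BlockStored payload layouts (both map and sequence encodings).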
#[cfg(test)]
mod test_event_processing {
use super::*;
use dynamo_kv_router::protocols::{BlockHashOptions, compute_block_hash_for_seq};
#[test]
fn test_create_stored_block_from_parts() {
let kv_block_size = 4;
let token_ids = vec![10, 20, 30, 40];
let blk_hash = 0xdead_beef;
let stored =
create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, None, None, None);
assert_eq!(stored.block_hash.0, blk_hash);
let expected_hash =
compute_block_hash_for_seq(&token_ids, 4, BlockHashOptions::default())[0];
assert_eq!(stored.tokens_hash, expected_hash);
assert!(stored.mm_extra_info.is_none());
}
#[test]
fn test_create_stored_blocks_ok() {
let kv_block_size = 4;
let token_ids = vec![1, 2, 3, 4, 5, 6, 7, 8];
let num_block_tokens = vec![4_u64, 4_u64];
let block_hashes = vec![111_u64, 222_u64];
let blocks = create_stored_blocks(
kv_block_size,
&token_ids,
&num_block_tokens,
&block_hashes,
None,
&Arc::new(AtomicU32::new(0)),
None,
None,
);
assert_eq!(blocks.len(), 2);
assert_eq!(blocks[0].block_hash.0, 111);
assert_eq!(blocks[1].block_hash.0, 222);
}
#[test]
fn test_create_stored_blocks_wrong_size_triggers_warning() {
let kv_block_size = 4;
let token_ids = vec![1, 2, 3, 4, 5, 6, 7];
let num_block_tokens = vec![4_u64, 3_u64];
let block_hashes = vec![111_u64, 222_u64];
let warning_count = Arc::new(AtomicU32::new(0));
let blocks = create_stored_blocks(
kv_block_size,
&token_ids,
&num_block_tokens,
&block_hashes,
None,
&warning_count,
None,
None,
);
assert_eq!(blocks.len(), 1);
assert_eq!(warning_count.load(Ordering::Relaxed), 1);
}
#[test]
fn test_convert_event_block_stored() {
let kv_block_size = 4;
let raw_evt = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(10), BlockHashValue::Unsigned(11)],
parent_block_hash: Some(BlockHashValue::Unsigned(99)),
token_ids: vec![1, 2, 3, 4, 5, 6, 7, 8],
block_size: 4,
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let out = convert_event(
raw_evt,
42,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.event.data, KvCacheEventData::Stored(_)));
}
#[test]
fn test_convert_event_with_lora_name() {
let kv_block_size = 4;
let token_ids = vec![1, 2, 3, 4];
let base_evt = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(10)],
parent_block_hash: None,
token_ids: token_ids.clone(),
block_size: 4,
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let lora_evt = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(10)],
parent_block_hash: None,
token_ids: token_ids.clone(),
block_size: 4,
medium: None,
lora_name: Some("my-lora".to_string()),
block_mm_infos: None,
is_eagle: None,
};
let wc = Arc::new(AtomicU32::new(0));
let base_out = convert_event(
base_evt,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let lora_out = convert_event(
lora_evt,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let base_hash = match &base_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
let lora_hash = match &lora_out.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
assert_ne!(
base_hash, lora_hash,
"LoRA blocks must produce distinct tokens_hash"
);
}
#[test]
fn test_convert_event_lora_name_none_is_base_model() {
let kv_block_size = 4;
let token_ids = vec![1, 2, 3, 4];
let wc = Arc::new(AtomicU32::new(0));
let evt1 = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(10)],
parent_block_hash: None,
token_ids: token_ids.clone(),
block_size: 4,
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let evt2 = RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(10)],
parent_block_hash: None,
token_ids: token_ids.clone(),
block_size: 4,
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
};
let out1 = convert_event(
evt1,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let out2 = convert_event(
evt2,
2,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&wc,
);
let hash1 = match &out1.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
let hash2 = match &out2.event.data {
KvCacheEventData::Stored(s) => s.blocks[0].tokens_hash,
_ => panic!("expected Stored"),
};
assert_eq!(
hash1, hash2,
"Two base-model events with same tokens should produce same hash"
);
}
#[test]
fn test_backward_compat_deserialize_map_with_lora_id_no_lora_name() {
#[derive(serde::Serialize)]
struct OldFormatEvent {
#[serde(rename = "type")]
event_type: &'static str,
block_hashes: Vec<u64>,
parent_block_hash: Option<u64>,
token_ids: Vec<u32>,
block_size: usize,
lora_id: Option<u64>,
}
let payload = rmps::to_vec(&OldFormatEvent {
event_type: "BlockStored",
block_hashes: vec![42],
parent_block_hash: None,
token_ids: vec![1, 2, 3, 4],
block_size: 4,
lora_id: Some(5),
})
.unwrap();
let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
let RawKvEvent::BlockStored { lora_name, .. } = event else {
panic!("expected BlockStored");
};
assert!(
lora_name.is_none(),
"old-format payloads with lora_id but no lora_name should deserialize with lora_name=None"
);
}
#[test]
fn test_backward_compat_deserialize_seq_with_lora_id_no_lora_name() {
let payload = rmps::to_vec(&(
"BlockStored",
vec![42_u64],
None::<u64>,
vec![1_u32, 2, 3, 4],
4_usize,
Some(5_u64),
))
.unwrap();
let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
let RawKvEvent::BlockStored { lora_name, .. } = event else {
panic!("expected BlockStored");
};
assert!(
lora_name.is_none(),
"old seq-format payloads with lora_id should deserialize with lora_name=None"
);
}
#[test]
fn test_convert_event_block_removed() {
let kv_block_size = 4;
let raw_evt = RawKvEvent::BlockRemoved {
block_hashes: vec![BlockHashValue::Unsigned(123), BlockHashValue::Signed(456)],
medium: None,
};
let out = convert_event(
raw_evt,
7,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.event.data, KvCacheEventData::Removed(_)));
}
#[test]
fn test_convert_event_all_blocks_cleared() {
let kv_block_size = 4;
let raw_evt = RawKvEvent::AllBlocksCleared;
let out = convert_event(
raw_evt,
1,
kv_block_size,
WorkerWithDpRank::from_worker_id(1),
&Arc::new(AtomicU32::new(0)),
);
assert!(matches!(out.event.data, KvCacheEventData::Cleared));
}
#[test]
fn test_parse_mm_hash_from_extra_key() {
assert_eq!(
parse_mm_hash_from_extra_key(
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210"
),
Some(0x0123_4567_89ab_cdef)
);
assert_eq!(parse_mm_hash_from_extra_key("123"), None);
assert_eq!(parse_mm_hash_from_extra_key("not_a_hash"), None);
}
#[test]
fn test_extra_keys_to_block_mm_infos() {
let mm_hash =
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
let infos = extra_keys_to_block_mm_infos(Some(vec![
Some(vec![ExtraKeyItem::Hash(mm_hash.clone())]),
None,
Some(vec![
ExtraKeyItem::Hash("invalid".to_string()),
ExtraKeyItem::Hash(mm_hash),
]),
]))
.expect("expected parsed MM infos");
assert_eq!(infos.len(), 3);
assert_eq!(
infos[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
assert!(infos[1].is_none());
assert_eq!(
infos[2].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
#[test]
fn test_seq_block_stored_field8_supports_extra_keys() {
let mm_hash =
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
let extra_keys_payload = rmps::to_vec(&(
"BlockStored",
vec![10_u64],
None::<u64>,
vec![1_u32, 2, 3, 4],
4_usize,
None::<u64>,
None::<String>,
None::<String>,
vec![Some(vec![mm_hash])],
))
.unwrap();
let extra_keys_event: RawKvEvent = rmps::from_slice(&extra_keys_payload).unwrap();
let RawKvEvent::BlockStored {
lora_name,
block_mm_infos,
..
} = extra_keys_event
else {
panic!("expected BlockStored");
};
assert!(lora_name.is_none());
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
#[test]
fn test_seq_block_stored_field8_supports_tuple_extra_keys() {
let mm_hash =
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
let extra_keys_payload = rmps::to_vec(&(
"BlockStored",
vec![10_u64],
None::<u64>,
vec![1_u32, 2, 3, 4],
4_usize,
None::<u64>,
None::<String>,
None::<String>,
vec![Some(vec![(mm_hash, 7_i64)])],
))
.unwrap();
let extra_keys_event: RawKvEvent = rmps::from_slice(&extra_keys_payload).unwrap();
let RawKvEvent::BlockStored { block_mm_infos, .. } = extra_keys_event else {
panic!("expected BlockStored");
};
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
#[test]
fn test_map_block_stored_supports_extra_keys() {
#[derive(serde::Serialize)]
struct MapBlockStoredEvent {
#[serde(rename = "type")]
event_type: &'static str,
block_hashes: Vec<u64>,
parent_block_hash: Option<u64>,
token_ids: Vec<u32>,
block_size: usize,
lora_id: Option<u64>,
medium: Option<String>,
lora_name: Option<String>,
extra_keys: Option<Vec<Option<Vec<String>>>>,
}
let payload = rmps::to_vec(&MapBlockStoredEvent {
event_type: "BlockStored",
block_hashes: vec![10],
parent_block_hash: None,
token_ids: vec![1, 2, 3, 4],
block_size: 4,
lora_id: None,
medium: Some("GPU".to_string()),
lora_name: None,
extra_keys: Some(vec![Some(vec![
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string(),
])]),
})
.unwrap();
let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
let RawKvEvent::BlockStored { block_mm_infos, .. } = event else {
panic!("expected BlockStored");
};
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
#[test]
fn test_map_block_stored_supports_tuple_extra_keys() {
type BlockTupleExtraKeys = Option<Vec<Option<Vec<(String, i64)>>>>;
#[derive(serde::Serialize)]
struct MapBlockStoredEvent {
#[serde(rename = "type")]
event_type: &'static str,
block_hashes: Vec<u64>,
parent_block_hash: Option<u64>,
token_ids: Vec<u32>,
block_size: usize,
lora_id: Option<u64>,
medium: Option<String>,
lora_name: Option<String>,
extra_keys: BlockTupleExtraKeys,
}
let mm_hash =
"0123456789abcdef00112233445566778899aabbccddeefffedcba9876543210".to_string();
let payload = rmps::to_vec(&MapBlockStoredEvent {
event_type: "BlockStored",
block_hashes: vec![10],
parent_block_hash: None,
token_ids: vec![1, 2, 3, 4],
block_size: 4,
lora_id: None,
medium: Some("GPU".to_string()),
lora_name: None,
extra_keys: Some(vec![Some(vec![(mm_hash, 3)])]),
})
.unwrap();
let event: RawKvEvent = rmps::from_slice(&payload).unwrap();
let RawKvEvent::BlockStored { block_mm_infos, .. } = event else {
panic!("expected BlockStored");
};
assert_eq!(
block_mm_infos.unwrap()[0].as_ref().unwrap().mm_objects[0].mm_hash,
0x0123_4567_89ab_cdef
);
}
}
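// Tests for the startup helpers: the event-processor task (with and without a
// local KV indexer), the ZMQ listener, and recovery of a router-side indexer from
// the worker-side event buffer. MockComponent is a RouterEventSink that records
// every published (subject, payload) pair for later inspection.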
#[cfg(test)]
mod tests_startup_helpers {
use super::*;
use crate::utils::zmq::{bind_pub_socket, send_multipart};
use bytes::Bytes;
use dynamo_kv_router::indexer::{
GetWorkersRequest, KvIndexer, KvIndexerInterface, WorkerKvQueryResponse,
};
use dynamo_kv_router::protocols::{ExternalSequenceBlockHash, LocalBlockHash};
use std::sync::{Arc, Mutex};
type PublishedEvents = Arc<Mutex<Vec<(String, Vec<u8>)>>>;
#[derive(Default)]
struct MockComponent {
published: PublishedEvents,
}
impl MockComponent {
fn new() -> (Self, PublishedEvents) {
let published = Arc::new(Mutex::new(Vec::new()));
(
Self {
published: published.clone(),
},
published,
)
}
}
impl RouterEventSink for MockComponent {
fn publish_event(
&self,
event: &RouterEvent,
) -> impl Future<Output = anyhow::Result<()>> + Send {
let bytes = rmp_serde::to_vec(event).unwrap();
self.published
.lock()
.unwrap()
.push((KV_EVENT_SUBJECT.to_string(), bytes));
async { Ok(()) }
}
}
fn local_gpu_event(worker_id: WorkerId, event: KvCacheEvent) -> PlacementEvent {
PlacementEvent::local_gpu(worker_id, event)
}
#[tokio::test]
async fn test_start_event_processor() {
let (component, published) = MockComponent::new();
let event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
}),
dp_rank: 0,
};
let token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, event)).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(
component,
1,
token,
rx,
None,
Some(10_000),
));
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
let published = published.lock().unwrap();
assert_eq!(published.len(), 1);
let (subject, _) = &published[0];
assert_eq!(subject, KV_EVENT_SUBJECT);
}
#[tokio::test]
async fn test_start_event_processor_with_local_indexer() {
let (component, published) = MockComponent::new();
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
let event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
},
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(101),
tokens_hash: LocalBlockHash(201),
mm_extra_info: None,
},
],
}),
dp_rank: 0,
};
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, event)).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(
component,
1,
token.clone(),
rx,
Some(local_indexer.clone()),
Some(10_000),
));
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
{
let published_events = published.lock().unwrap();
assert_eq!(published_events.len(), 1);
let (subject, _) = &published_events[0];
assert_eq!(subject, KV_EVENT_SUBJECT);
}
let get_workers_tx = local_indexer.get_workers_sender();
let mut found = false;
for _ in 0..20 {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
get_workers_tx
.send(GetWorkersRequest { resp: resp_tx })
.await
.unwrap();
let workers: Vec<u64> = resp_rx.await.unwrap();
if workers.contains(&1) {
found = true;
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
assert!(
found,
"Worker 1 was not found in the indexer after processing"
);
token.cancel();
}
#[tokio::test]
async fn test_event_processor_block_removed_with_local_indexer() {
let (component, published) = MockComponent::new();
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
let store_event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
};
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, store_event)).unwrap();
let handle = tokio::spawn(start_event_processor(
component,
1,
token.clone(),
rx,
Some(local_indexer.clone()),
Some(10_000),
));
let remove_event = KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(100)],
}),
dp_rank: 0,
};
tx.send(local_gpu_event(1, remove_event)).unwrap();
drop(tx);
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
let mut no_blocks = false;
for _ in 0..20 {
let scores = local_indexer
.find_matches(vec![LocalBlockHash(200)])
.await
.unwrap();
if scores.scores.is_empty() {
no_blocks = true;
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
assert!(no_blocks, "worker should have no blocks after removal");
let published = published.lock().unwrap();
assert_eq!(
published.len(),
2,
"expected 2 published events, found {}",
published.len()
);
token.cancel();
}
#[tokio::test]
async fn test_event_processor_all_blocks_cleared_with_local_indexer() {
let (component, published) = MockComponent::new();
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
let store_event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
};
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, store_event)).unwrap();
let clear_event = KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Cleared,
dp_rank: 0,
};
tx.send(local_gpu_event(1, clear_event)).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(
component,
1,
token.clone(),
rx,
Some(local_indexer.clone()),
Some(10_000),
));
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
let mut no_blocks = false;
for _ in 0..20 {
let scores = local_indexer
.find_matches(vec![LocalBlockHash(200)])
.await
.unwrap();
if scores.scores.is_empty() {
no_blocks = true;
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
assert!(no_blocks, "worker should have no blocks after clearing");
let published = published.lock().unwrap();
assert_eq!(
published.len(),
2,
"expected 2 published events, found {}",
published.len()
);
token.cancel();
}
#[tokio::test]
async fn test_event_processor_local_indexer_failure_continues() {
let (component, published) = MockComponent::new();
let token = CancellationToken::new();
let metrics = Arc::new(KvIndexerMetrics::new_unregistered());
let local_indexer = Arc::new(LocalKvIndexer::new(token.clone(), 4, metrics, 100));
token.cancel();
let event = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1)],
}),
dp_rank: 0,
};
let new_token = CancellationToken::new();
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
tx.send(local_gpu_event(1, event)).unwrap();
drop(tx);
let handle = tokio::spawn(start_event_processor(
component,
1,
new_token,
rx,
Some(local_indexer),
Some(10_000),
));
tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
.await
.unwrap()
.unwrap();
let published_events = published.lock().unwrap();
assert_eq!(published_events.len(), 1);
}
#[tokio::test]
async fn test_start_zmq_listener_pushes_to_channel() {
let (tx, mut rx) = mpsc::unbounded_channel::<PlacementEvent>();
let reserved_listener = reserve_open_port();
let endpoint = format!(
"tcp://127.0.0.1:{}",
reserved_listener
.local_addr()
.expect("failed to read reserved listener address")
.port()
);
drop(reserved_listener);
let topic = "".to_string();
let pub_socket = bind_pub_socket(&endpoint).await.unwrap();
let token = dynamo_runtime::CancellationToken::new();
let next_event_id = Arc::new(AtomicU64::new(0));
let listener_handle = tokio::spawn({
let token = token.clone();
start_zmq_listener(endpoint.to_string(), topic, 1, tx, token, 4, next_event_id)
});
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let seq: u64 = 77;
let events = vec![RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(42)],
parent_block_hash: None,
token_ids: vec![0, 1, 2, 3],
block_size: 4,
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
}];
let batch = KvEventBatch {
ts: 0.0,
events,
data_parallel_rank: Some(1),
};
let payload = Bytes::from(rmps::to_vec(&batch).unwrap());
let frames = vec![
Bytes::from("").to_vec(),
Bytes::from(seq.to_be_bytes().to_vec()).to_vec(),
payload.clone().to_vec(),
];
send_multipart(&pub_socket, frames).await.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let event = rx.try_recv().expect("no message received").event;
let KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks,
}) = event.data
else {
panic!("expected KvCacheStoreData");
};
assert!(parent_hash.is_none());
assert_eq!(blocks.len(), 1);
assert_eq!(blocks[0].block_hash.0, 42);
token.cancel();
let _ = listener_handle.await;
}
#[tokio::test]
async fn test_start_zmq_listener_connects_before_publisher_bind() {
let (tx, mut rx) = mpsc::unbounded_channel::<PlacementEvent>();
let reserved_listener = reserve_open_port();
let endpoint = format!(
"tcp://127.0.0.1:{}",
reserved_listener
.local_addr()
.expect("failed to read reserved listener address")
.port()
);
drop(reserved_listener);
let topic = String::new();
let token = dynamo_runtime::CancellationToken::new();
let next_event_id = Arc::new(AtomicU64::new(0));
let listener_handle = tokio::spawn({
let token = token.clone();
let endpoint = endpoint.clone();
start_zmq_listener(endpoint, topic, 1, tx, token, 4, next_event_id)
});
tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;
let pub_socket = bind_pub_socket(&endpoint).await.unwrap();
let batch = KvEventBatch {
ts: 0.0,
events: vec![RawKvEvent::BlockStored {
block_hashes: vec![BlockHashValue::Unsigned(64)],
parent_block_hash: None,
token_ids: vec![4, 5, 6, 7],
block_size: 4,
medium: None,
lora_name: None,
block_mm_infos: None,
is_eagle: None,
}],
data_parallel_rank: Some(0),
};
let payload = rmps::to_vec(&batch).unwrap();
for _ in 0..5 {
send_multipart(
&pub_socket,
vec![Vec::new(), 12u64.to_be_bytes().to_vec(), payload.clone()],
)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
let event = tokio::time::timeout(tokio::time::Duration::from_secs(5), rx.recv())
.await
.expect("timed out waiting for listener event")
.expect("listener channel closed")
.event;
let KvCacheEventData::Stored(KvCacheStoreData { blocks, .. }) = event.data else {
panic!("expected KvCacheStoreData");
};
assert_eq!(blocks[0].block_hash.0, 64);
token.cancel();
let _ = listener_handle.await;
}
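// Reserve an ephemeral port by binding a throwaway listener; callers drop the
// listener and reuse its port for the ZMQ endpoint. Best effort only: another
// process could claim the port between the drop and the subsequent bind.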
fn reserve_open_port() -> std::net::TcpListener {
std::net::TcpListener::bind("127.0.0.1:0").expect("failed to bind probe listener")
}
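// Simulates a router-side outage: the worker keeps buffering events in its local
// indexer while the router misses event_2, then the missed events are replayed
// from the worker buffer and the router ends up seeing both blocks.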
#[tokio::test]
async fn test_distributed_kvindexer_recovery_from_outage() {
let worker_1_id = 1u64;
let block_size = 4u32;
let token = CancellationToken::new();
let (worker_component, worker_published) = MockComponent::new();
let local_indexer_1 = Arc::new(LocalKvIndexer::new(
token.clone(),
block_size,
Arc::new(KvIndexerMetrics::new_unregistered()),
100,
));
let (worker_tx, worker_rx) = mpsc::unbounded_channel::<PlacementEvent>();
tokio::spawn(start_event_processor(
worker_component,
worker_1_id,
token.clone(),
worker_rx,
Some(local_indexer_1.clone()),
Some(10),
));
let router_indexer = Arc::new(KvIndexer::new(
token.clone(),
block_size,
Arc::new(KvIndexerMetrics::new_unregistered()),
));
let event_1 = KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
},
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(101),
tokens_hash: LocalBlockHash(201),
mm_extra_info: None,
},
],
}),
dp_rank: 0,
};
worker_tx
.send(local_gpu_event(worker_1_id, event_1.clone()))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let (subject, bytes) = {
let published = worker_published.lock().unwrap();
assert_eq!(published.len(), 1, "Worker should have published 1 event");
(published[0].0.clone(), published[0].1.clone())
};
assert_eq!(subject, KV_EVENT_SUBJECT);
let router_event: RouterEvent = rmp_serde::from_slice(&bytes).unwrap();
router_indexer
.event_sender()
.send(router_event)
.await
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let get_workers_tx = router_indexer.get_workers_sender();
let mut router_has_worker = false;
for _ in 0..20 {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
get_workers_tx
.send(GetWorkersRequest { resp: resp_tx })
.await
.unwrap();
let workers: Vec<u64> = resp_rx.await.unwrap();
if workers.contains(&worker_1_id) {
router_has_worker = true;
break;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
assert!(
router_has_worker,
"Router should see worker 1 after normal operation"
);
match local_indexer_1.get_events_in_id_range(Some(1), None).await {
WorkerKvQueryResponse::Events { events, .. } => {
assert_eq!(events.len(), 1, "Local indexer should buffer 1 event");
}
other => panic!("Expected buffered events, got {other:?}"),
}
let event_2 = KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
},
KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(102),
tokens_hash: LocalBlockHash(202),
mm_extra_info: None,
},
],
}),
dp_rank: 0,
};
worker_tx
.send(local_gpu_event(worker_1_id, event_2.clone()))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
{
let published = worker_published.lock().unwrap();
assert_eq!(
published.len(),
2,
"Worker should have published 2 events total"
);
}
match local_indexer_1.get_events_in_id_range(Some(1), None).await {
WorkerKvQueryResponse::Events { events, .. } => {
assert_eq!(
events.len(),
2,
"Local indexer should have both events during outage"
);
}
other => panic!("Expected buffered events, got {other:?}"),
}
let block_hashes_2 = vec![LocalBlockHash(200), LocalBlockHash(202)];
let overlap = router_indexer
.find_matches(block_hashes_2.clone())
.await
.unwrap();
let router_overlap = overlap
.scores
.get(&dynamo_kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id))
.copied()
.unwrap_or(0);
assert_eq!(
router_overlap, 1,
"Router should only see 1 shared block (not the new block from event_2)"
);
let last_known_id = 1u64;
let response = local_indexer_1
.get_events_in_id_range(Some(last_known_id + 1), None)
.await;
let missed_events = match response {
dynamo_kv_router::indexer::WorkerKvQueryResponse::Events { events: e, .. } => e,
dynamo_kv_router::indexer::WorkerKvQueryResponse::TreeDump { events: e, .. } => e,
dynamo_kv_router::indexer::WorkerKvQueryResponse::Error(message) => {
panic!("Unexpected error response: {message}")
}
other => panic!("Unexpected response: {:?}", other),
};
assert_eq!(
missed_events.len(),
1,
"Should get 1 missed event (event_2 with id=2)"
);
for router_event in missed_events {
router_indexer
.event_sender()
.send(router_event)
.await
.unwrap();
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let overlap = router_indexer.find_matches(block_hashes_2).await.unwrap();
let router_overlap_after = overlap
.scores
.get(&dynamo_kv_router::protocols::WorkerWithDpRank::from_worker_id(worker_1_id))
.copied()
.unwrap_or(0);
assert_eq!(
router_overlap_after, 2,
"Router should now see both blocks after recovery"
);
token.cancel();
}
}
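// Tests for EventDedupFilter: repeated stores of the same block are reference
// counted per dp_rank, so a remove is suppressed until the last outstanding store
// for that (rank, block) pair has been removed.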
#[cfg(test)]
mod test_event_dedup_filter {
use super::*;
fn store_data(hashes: &[u64]) -> KvCacheStoreData {
KvCacheStoreData {
parent_hash: None,
blocks: hashes
.iter()
.map(|&h| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(h),
tokens_hash: LocalBlockHash(h * 10),
mm_extra_info: None,
})
.collect(),
}
}
fn remove_data(hashes: &[u64]) -> KvCacheRemoveData {
KvCacheRemoveData {
block_hashes: hashes
.iter()
.map(|&h| ExternalSequenceBlockHash(h))
.collect(),
}
}
#[test]
fn stores_track_refcounts_for_removes() {
let mut filter = EventDedupFilter::new();
let data = store_data(&[1, 2, 3]);
filter.track_store(0, &data);
filter.track_store(0, &data);
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
assert!(result.is_none());
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 3);
}
#[test]
fn duplicate_removes_are_filtered() {
let mut filter = EventDedupFilter::new();
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_none());
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
assert_eq!(result.unwrap().block_hashes.len(), 1);
}
#[test]
fn store_remove_store_cycle() {
let mut filter = EventDedupFilter::new();
filter.track_store(0, &store_data(&[1]));
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
filter.track_store(0, &store_data(&[1]));
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
}
#[test]
fn clear_resets_all_ranks() {
let mut filter = EventDedupFilter::new();
filter.track_store(0, &store_data(&[1, 2]));
filter.track_store(0, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2]));
filter.track_store(1, &store_data(&[1, 2]));
filter.clear();
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
let result = filter.filter_remove(1, remove_data(&[1]));
assert!(result.is_some());
}
#[test]
fn mixed_blocks_in_single_remove() {
let mut filter = EventDedupFilter::new();
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[2]));
filter.track_store(0, &store_data(&[3]));
filter.track_store(0, &store_data(&[3]));
let result = filter.filter_remove(0, remove_data(&[1, 2, 3]));
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.block_hashes.len(), 1);
assert_eq!(result.block_hashes[0], ExternalSequenceBlockHash(2));
}
#[test]
fn same_hash_on_different_ranks_are_independent() {
let mut filter = EventDedupFilter::new();
filter.track_store(0, &store_data(&[1]));
filter.track_store(0, &store_data(&[1]));
filter.track_store(1, &store_data(&[1]));
let result = filter.filter_remove(1, remove_data(&[1]));
assert!(result.is_some());
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_none());
let result = filter.filter_remove(0, remove_data(&[1]));
assert!(result.is_some());
}
}
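// Integration test for WorkerMetricsPublisher: publishes load metrics over the
// event plane and checks that only changed values are re-published. Gated behind
// the `integration` feature and marked #[ignore] since it needs a live test
// distributed runtime.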
#[cfg(all(test, feature = "integration"))]
mod test_integration_publisher {
use super::*;
use crate::kv_router::KV_METRICS_SUBJECT;
use dynamo_kv_router::protocols::ActiveLoad;
use dynamo_runtime::distributed_test_utils::create_test_drt_async;
use dynamo_runtime::transports::event_plane::EventSubscriber;
#[tokio::test]
#[ignore]
async fn test_metrics_publishing_behavior() -> Result<()> {
let drt = create_test_drt_async().await;
let namespace = drt.namespace("ns2001".to_string())?;
let mut subscriber = EventSubscriber::for_namespace(&namespace, KV_METRICS_SUBJECT)
.await
.unwrap()
.typed::<ActiveLoad>();
let publisher = WorkerMetricsPublisher::new().unwrap();
let worker_id = 1234;
publisher.start_nats_metrics_publishing(namespace.clone(), worker_id);
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
for i in 0..10 {
let value = (i * 100) as u64;
publisher.publish(None, None, Some(value)).unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let result =
tokio::time::timeout(tokio::time::Duration::from_millis(500), subscriber.next())
.await
.unwrap();
let (_envelope, event) = result.unwrap().unwrap();
assert_eq!(event.worker_id, worker_id);
assert_eq!(event.active_decode_blocks, None);
assert_eq!(event.active_prefill_tokens, None);
assert_eq!(event.kv_used_blocks, Some(900));
let no_msg =
tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await;
assert!(no_msg.is_err(), "Expected no more messages, but found one");
for _ in 0..10 {
publisher.publish(None, None, Some(900)).unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let no_msg =
tokio::time::timeout(tokio::time::Duration::from_millis(50), subscriber.next()).await;
assert!(
no_msg.is_err(),
"Expected no messages when load metrics don't change"
);
drt.shutdown();
Ok(())
}
}
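// Unit tests for BatchingState: pending stored/removed accumulation, flush-time
// bookkeeping, and the remaining-timeout calculation used by the batching loop.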
#[cfg(test)]
mod batching_state_tests {
use super::*;
#[test]
fn test_batching_state_default() {
let state = BatchingState::new();
assert!(!state.has_pending(), "Default state should have no pending");
assert!(
state.pending_removed.is_none(),
"Default pending_removed should be None"
);
assert!(
state.pending_stored.is_none(),
"Default pending_stored should be None"
);
}
#[test]
fn test_batching_state_new() {
let state = BatchingState::new();
let elapsed = state.last_flush_time.elapsed();
assert!(
elapsed < Duration::from_secs(1),
"new() should create state with flush time set to approximately now"
);
}
#[test]
fn test_batching_state_pending_removed() {
let mut state = BatchingState::new();
assert!(!state.has_pending(), "Should not have pending initially");
state.pending_removed = Some(KvCacheRemoveData {
block_hashes: vec![],
});
assert!(
state.has_pending(),
"Should have pending after setting pending_removed"
);
}
#[test]
fn test_batching_state_pending_stored() {
let mut state = BatchingState::new();
assert!(!state.has_pending(), "Should not have pending initially");
state.pending_stored = Some(KvCacheStoreData {
parent_hash: None,
blocks: vec![],
});
assert!(
state.has_pending(),
"Should have pending after setting pending_stored"
);
}
#[test]
fn test_batching_state_timeout() {
let mut state = BatchingState::new();
state.record_flush_time();
let remaining_before = state.remaining_timeout(10);
assert!(
remaining_before.as_millis() > 0,
"Should have remaining time initially"
);
let remaining_zero = state.remaining_timeout(0);
assert_eq!(
remaining_zero.as_millis(),
0,
"0 timeout should return zero"
);
}
#[test]
fn test_batching_state_record_flush_time() {
let mut state = BatchingState::new();
let initial_time = state.last_flush_time;
state.record_flush_time();
assert!(
state.last_flush_time >= initial_time,
"record_flush_time should update the time"
);
}
#[test]
fn test_batching_state_remaining_timeout() {
let mut state = BatchingState::new();
state.record_flush_time();
let remaining = state.remaining_timeout(10);
assert!(
remaining.as_millis() > 0,
"Should have remaining time initially"
);
let remaining_zero = state.remaining_timeout(0);
assert_eq!(
remaining_zero,
Duration::ZERO,
"0 timeout should return zero"
);
}
#[test]
fn test_batching_state_accumulate_removed() {
let mut state = BatchingState::new();
let first = KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(1), ExternalSequenceBlockHash(2)],
};
state.pending_removed = Some(first);
if let Some(ref mut pending) = state.pending_removed {
pending
.block_hashes
.extend(vec![ExternalSequenceBlockHash(3)]);
}
let pending = state.pending_removed.as_ref().unwrap();
assert_eq!(
pending.block_hashes.len(),
3,
"Should have accumulated 3 block hashes"
);
}
#[test]
fn test_batching_state_accumulate_stored() {
let mut state = BatchingState::new();
let block1 = KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(1),
tokens_hash: LocalBlockHash(100),
mm_extra_info: None,
};
let first = KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)),
blocks: vec![block1],
};
state.pending_stored = Some(first);
let block2 = KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(2),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
};
if let Some(ref mut pending) = state.pending_stored {
pending.blocks.extend(vec![block2]);
}
let pending = state.pending_stored.as_ref().unwrap();
assert_eq!(pending.blocks.len(), 2, "Should have accumulated 2 blocks");
}
}
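// Behavioural tests for run_event_processor_loop against a MockPublisher sink:
// Removed/Stored events batch within the flush timeout, while broken parent-hash
// chains, event-type switches, and dp_rank changes force an immediate flush, and
// flushed batches carry fresh monotonic event_ids.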
#[cfg(test)]
mod event_processor_tests {
use super::*;
use std::sync::{Arc, Mutex};
use tokio_util::sync::CancellationToken;
#[derive(Debug, Clone)]
struct MockPublisher {
events: Arc<Mutex<Vec<RouterEvent>>>,
}
impl MockPublisher {
fn new() -> Self {
Self {
events: Arc::new(Mutex::new(Vec::new())),
}
}
fn get_events(&self) -> Vec<RouterEvent> {
self.events.lock().unwrap().clone()
}
}
impl RouterEventSink for MockPublisher {
fn publish_event(&self, event: &RouterEvent) -> impl Future<Output = Result<()>> + Send {
self.events.lock().unwrap().push(event.clone());
async { Ok(()) }
}
}
fn local_gpu_event(event: KvCacheEvent) -> PlacementEvent {
PlacementEvent::local_gpu(1, event)
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_removed_events_20() {
test_removed_events_batching(20, Some(10)).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_removed_events_10() {
test_removed_events_batching(10, Some(10)).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_removed_events_5() {
test_removed_events_batching(5, Some(10)).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_removed_events_3() {
test_removed_events_batching(3, Some(10)).await;
}
async fn test_removed_events_batching(event_count: usize, timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
for i in 0..event_count {
let event = KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
};
tx.send(local_gpu_event(event)).unwrap();
tokio::task::yield_now().await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(
timeout_ms.unwrap_or(0) + 1,
))
.await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert!(
!events.is_empty(),
"Should have received at least one event"
);
assert!(
events.len() <= 2,
"With long timeout ({timeout_ms:?}), all {event_count} events should batch into at most 2 output events (got {})",
events.len()
);
let total_hashes: usize = events
.iter()
.map(|e| {
if let KvCacheEventData::Removed(data) = &e.event.data {
data.block_hashes.len()
} else {
0
}
})
.sum();
assert_eq!(
total_hashes, event_count,
"All {} block hashes should be accounted for",
event_count
);
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_stored_events_20() {
test_stored_events_batching(20, Some(100)).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_stored_events_10() {
test_stored_events_batching(10, Some(100)).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_stored_events_5() {
test_stored_events_batching(5, Some(100)).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_batches_stored_events_3() {
test_stored_events_batching(3, Some(100)).await;
}
async fn test_stored_events_batching(event_count: usize, timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
for i in 0..event_count {
let parent_hash = if i == 0 {
Some(ExternalSequenceBlockHash(0))
} else {
Some(ExternalSequenceBlockHash((i - 1) as u64))
};
let event = KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(i as u64),
tokens_hash: LocalBlockHash(i as u64 * 100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
};
tx.send(local_gpu_event(event)).unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert!(
!events.is_empty(),
"Should have received at least one event"
);
assert!(
events.len() <= 2,
"With long timeout ({timeout_ms:?}) and sequential parent hashes, all {event_count} events should batch into at most 2 output events (got {})",
events.len()
);
if events.len() == 2 {
if let KvCacheEventData::Stored(data) = &events[0].event.data {
assert_eq!(
data.blocks.len(),
1,
"If 2 events, first event should have 1 block (got {})",
data.blocks.len()
);
} else {
panic!("Expected Stored event");
}
}
let total_blocks: usize = events
.iter()
.map(|e| {
if let KvCacheEventData::Stored(data) = &e.event.data {
data.blocks.len()
} else {
0
}
})
.sum();
assert_eq!(
total_blocks, event_count,
"All {} blocks should be accounted for",
event_count
);
}
#[tokio::test]
async fn test_run_event_processor_loop_non_sequential_flush() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
for i in 0..3 {
let event = KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash((i + 1) as u64 * 100)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(i as u64),
tokens_hash: LocalBlockHash(i as u64 * 100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
};
tx.send(local_gpu_event(event)).unwrap();
}
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert!(!events.is_empty(), "Should have received events");
assert_eq!(
events.len(),
3,
"Non-sequential events should trigger flush, resulting in 3 separate events"
);
let total_blocks: usize = events
.iter()
.map(|e| {
if let KvCacheEventData::Stored(data) = &e.event.data {
data.blocks.len()
} else {
0
}
})
.sum();
assert_eq!(total_blocks, 3, "All 3 blocks should be accounted for");
}
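// An event that reuses an already-extended parent hash cannot continue the current
// chain, so the pending batch is flushed and the event starts a new batch that
// keeps its own parent hash.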
#[tokio::test]
async fn test_run_event_processor_loop_reused_parent_hash_breaks_chain() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_gpu_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(1),
tokens_hash: LocalBlockHash(100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(2),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(3),
tokens_hash: LocalBlockHash(300),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Reused parent hash should flush the current batch before starting a new one"
);
if let KvCacheEventData::Stored(data) = &events[0].event.data {
assert_eq!(
data.blocks.len(),
2,
"First batch should keep the valid chain"
);
assert_eq!(
data.parent_hash, None,
"First batch should preserve the original root parent"
);
} else {
panic!("Expected first event to be Stored");
}
if let KvCacheEventData::Stored(data) = &events[1].event.data {
assert_eq!(
data.blocks.len(),
1,
"Second batch should contain only the inconsistent event"
);
assert_eq!(
data.parent_hash,
Some(ExternalSequenceBlockHash(1)),
"Second batch should preserve the reused parent hash"
);
} else {
panic!("Expected second event to be Stored");
}
}
#[tokio::test]
async fn test_run_event_processor_loop_no_batching_with_slow_input_0ms() {
test_no_batching_with_slow_input(None).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_no_batching_with_slow_input_0_1ms() {
test_no_batching_with_slow_input(Some(1)).await;
}
#[tokio::test]
async fn test_run_event_processor_loop_no_batching_with_slow_input_0_2ms() {
test_no_batching_with_slow_input(Some(2)).await;
}
async fn test_no_batching_with_slow_input(timeout_ms: Option<u64>) {
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
for i in 0..5 {
let event = KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
};
tx.send(local_gpu_event(event)).unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert!(!events.is_empty(), "Should have received events");
assert!(
events.len() >= 3,
"With slow input (2ms delay) and timeout={timeout_ms:?}, should have at least 3 separate events (got {})",
events.len()
);
let total_hashes: usize = events
.iter()
.map(|e| {
if let KvCacheEventData::Removed(data) = &e.event.data {
data.block_hashes.len()
} else {
0
}
})
.sum();
assert_eq!(
total_hashes, 5,
"All 5 block hashes should be accounted for"
);
}
#[tokio::test]
async fn test_event_type_switching_causes_flush() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_gpu_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(0)],
}),
dp_rank: 0,
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_micros(100)).await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(1),
tokens_hash: LocalBlockHash(100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Switching from Removed to Stored should cause immediate flush, resulting in 2 separate events"
);
}
#[tokio::test]
async fn test_dp_rank_change_causes_flush() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
for i in 0..3 {
tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
}
for i in 3..6 {
tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 1,
}))
.unwrap();
tokio::task::yield_now().await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"dp_rank change should cause immediate flush, resulting in 2 separate events"
);
let total_hashes: usize = events
.iter()
.map(|e| {
if let KvCacheEventData::Removed(data) = &e.event.data {
data.block_hashes.len()
} else {
0
}
})
.sum();
assert_eq!(
total_hashes, 6,
"All 6 block hashes should be accounted for"
);
assert_eq!(
events[0].event.dp_rank, 0,
"First batch should have dp_rank=0"
);
assert_eq!(
events[1].event.dp_rank, 1,
"Second batch should have dp_rank=1"
);
}
#[tokio::test]
async fn test_flushed_events_have_correct_metadata() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
for i in 0..3 {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 10 + i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
}
for i in 0..2 {
tx.send(local_gpu_event(KvCacheEvent {
event_id: 20 + i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash((i + 3) as u64)],
}),
dp_rank: 1,
}))
.unwrap();
tokio::task::yield_now().await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Should have 2 events (one per dp_rank batch)"
);
assert_eq!(
events[0].event.dp_rank, 0,
"First batch should have dp_rank=0"
);
assert_eq!(
events[0].event.event_id, 1,
"First batch should have monotonic event_id=1"
);
assert_eq!(
events[1].event.dp_rank, 1,
"Second batch should have dp_rank=1"
);
assert_eq!(
events[1].event.event_id, 2,
"Second batch should have monotonic event_id=2"
);
}
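// After an idle period longer than the flush timeout, the first incoming event is
// treated as stale and flushed on its own; the subsequent events batch together.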
#[tokio::test]
async fn test_first_event_after_idle_flushes_immediately_then_batches() {
let timeout_ms = Some(50);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
for i in 0..3 {
tx.send(local_gpu_event(KvCacheEvent {
event_id: i as u64,
data: KvCacheEventData::Removed(KvCacheRemoveData {
block_hashes: vec![ExternalSequenceBlockHash(i as u64)],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
}
tokio::time::sleep(tokio::time::Duration::from_millis(60)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"First event should flush immediately (stale), remaining 2 should batch"
);
let first_len = if let KvCacheEventData::Removed(data) = &events[0].event.data {
data.block_hashes.len()
} else {
0
};
let second_len = if let KvCacheEventData::Removed(data) = &events[1].event.data {
data.block_hashes.len()
} else {
0
};
assert_eq!(first_len, 1, "First event should have 1 hash");
assert_eq!(second_len, 2, "Second event (batched) should have 2 hashes");
}
#[tokio::test]
async fn test_stored_events_with_dp_rank_change_correct_metadata() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_gpu_event(KvCacheEvent {
event_id: 100,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(1),
tokens_hash: LocalBlockHash(100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 101,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(2),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 200,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(0)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(1000),
mm_extra_info: None,
}],
}),
dp_rank: 1,
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
2,
"Should have 2 events (one per dp_rank batch)"
);
assert_eq!(
events[0].event.dp_rank, 0,
"First batch should have dp_rank=0"
);
assert_eq!(
events[0].event.event_id, 1,
"First batch should have monotonic event_id=1"
);
assert_eq!(
events[1].event.dp_rank, 1,
"Second batch should have dp_rank=1"
);
assert_eq!(
events[1].event.event_id, 2,
"Second batch should have monotonic event_id=2"
);
if let KvCacheEventData::Stored(data) = &events[0].event.data {
assert_eq!(data.blocks.len(), 2, "First batch should have 2 blocks");
} else {
panic!("Expected Stored event");
}
if let KvCacheEventData::Stored(data) = &events[1].event.data {
assert_eq!(data.blocks.len(), 1, "Second batch should have 1 block");
} else {
panic!("Expected Stored event");
}
}
#[tokio::test]
async fn test_batch_parent_hash_preserved_when_extending() {
let timeout_ms = Some(100);
let (tx, rx) = mpsc::unbounded_channel::<PlacementEvent>();
let publisher = MockPublisher::new();
let publisher_clone = publisher.clone();
let cancellation_token = CancellationToken::new();
let handle = tokio::spawn(async move {
run_event_processor_loop(
publisher_clone,
1,
cancellation_token,
rx,
None,
timeout_ms,
DEFAULT_MAX_BATCH_BLOCKS,
)
.await
});
tx.send(local_gpu_event(KvCacheEvent {
event_id: 0,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: None,
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(1),
tokens_hash: LocalBlockHash(100),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 1,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(1)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(2),
tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::task::yield_now().await;
tx.send(local_gpu_event(KvCacheEvent {
event_id: 2,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash: Some(ExternalSequenceBlockHash(2)),
blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(3),
tokens_hash: LocalBlockHash(300),
mm_extra_info: None,
}],
}),
dp_rank: 0,
}))
.unwrap();
tokio::time::sleep(tokio::time::Duration::from_millis(2)).await;
drop(tx);
handle.await.unwrap();
let events = publisher.get_events();
assert_eq!(
events.len(),
1,
"All 3 sequential events should batch into 1"
);
if let KvCacheEventData::Stored(data) = &events[0].event.data {
assert_eq!(data.blocks.len(), 3, "Batch should have 3 blocks");
assert_eq!(
data.parent_hash, None,
"Batch parent_hash should remain None (from first event), NOT overwritten by subsequent events"
);
} else {
panic!("Expected Stored event");
}
}
}