use crate::cache::HashCache;
use crate::common::kv_cache_trace;
use crate::common::protocols::{KvEventPublishers, MoveBlock, PrefillCost};
use crate::common::sequence::ActiveSequence;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData, LocalBlockHash,
};
use dynamo_tokens::blocks::UniqueBlock;
use dynamo_tokens::{BlockHash, SequenceHash};
use rustc_hash::FxHashMap;
/// Simulated KV-cache block manager for a single data-parallel rank.
///
/// Tracks active/inactive cache blocks via a [`HashCache`] and mirrors
/// allocation ("store") and eviction ("remove") activity out to the attached
/// [`KvEventPublishers`] as [`KvCacheEvent`]s with monotonically increasing
/// event ids.
pub struct KvManager {
    // Active/inactive block bookkeeping with a fixed maximum capacity.
    cache: HashCache,
    // Number of tokens represented by one cache block.
    block_size: usize,
    // Sinks that receive store/remove events; may be empty (publishing is a no-op).
    kv_event_publishers: KvEventPublishers,
    // Data-parallel rank stamped onto every published event and trace line.
    dp_rank: u32,
    // Monotonic id assigned to the next published `KvCacheEvent`.
    next_event_id: u64,
}
impl KvManager {
    /// Creates a manager with no event publishers (events are silently
    /// dropped) and DP rank 0.
    pub fn new(max_capacity: usize, block_size: usize) -> Self {
        Self::new_with_event_sink(max_capacity, block_size, KvEventPublishers::default(), 0)
    }

    /// Creates a manager that forwards store/remove events to
    /// `kv_event_publishers`, tagged with `dp_rank`.
    ///
    /// `max_capacity` is the total number of blocks (active + inactive) the
    /// cache may hold. It is asserted non-zero in debug builds only; a zero
    /// capacity in release builds would make the percentage getters divide
    /// by zero.
    pub fn new_with_event_sink(
        max_capacity: usize,
        block_size: usize,
        kv_event_publishers: KvEventPublishers,
        dp_rank: u32,
    ) -> Self {
        debug_assert!(max_capacity > 0, "max_capacity must be > 0");
        if !kv_event_publishers.is_empty() {
            tracing::info!(
                "KvManager initialized with event sink for DP rank {dp_rank} with block_size {block_size}"
            );
        }
        KvManager {
            cache: HashCache::new(max_capacity),
            block_size,
            kv_event_publishers,
            dp_rank,
            next_event_id: 0,
        }
    }

    /// Emits a single store or remove event covering `full_blocks`.
    ///
    /// * `full_blocks` — sequence hashes of the affected full blocks; the
    ///   call is a no-op when empty.
    /// * `local_hashes` — per-block local hashes; on the store path the
    ///   trailing `full_blocks.len()` entries are paired positionally with
    ///   `full_blocks`. Unused for remove events.
    /// * `parent_hash` — sequence hash of the block preceding the first
    ///   stored block, if any (store path only).
    /// * `is_store` — `true` produces a `Stored` event, `false` a `Removed`
    ///   event.
    /// * `token_ids` — optional per-block token ids forwarded to publishers.
    ///
    /// Also writes a vLLM-style trace line with the current cache occupancy,
    /// regardless of whether any publisher is attached.
    ///
    /// # Panics
    /// Panics (via `expect`) on the store path when `local_hashes` has fewer
    /// entries than `full_blocks`.
    fn publish_kv_event(
        &mut self,
        full_blocks: Vec<SequenceHash>,
        local_hashes: &[BlockHash],
        parent_hash: Option<u64>,
        is_store: bool,
        token_ids: Option<Vec<Vec<u32>>>,
    ) {
        if full_blocks.is_empty() {
            return;
        }
        // Trace the occupancy snapshot even when no publisher is configured.
        kv_cache_trace::log_vllm_trace(
            if is_store { "allocation" } else { "eviction" },
            self.dp_rank,
            self.block_size,
            self.cache.num_active(),
            self.cache.num_inactive(),
            self.cache.max_capacity(),
        );
        if self.kv_event_publishers.is_empty() {
            return;
        }
        let event_data = if is_store {
            let num_blocks = full_blocks.len();
            // Only the trailing `num_blocks` local hashes correspond to the
            // blocks actually stored by this call (earlier entries belong to
            // blocks that were already cached).
            let local_hashes_slice = &local_hashes[local_hashes
                .len()
                .checked_sub(num_blocks)
                .expect("local hashes fewer than stored blocks")..];
            KvCacheEventData::Stored(KvCacheStoreData {
                parent_hash: parent_hash.map(ExternalSequenceBlockHash),
                blocks: full_blocks
                    .into_iter()
                    .zip(local_hashes_slice.iter())
                    .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
                        block_hash: ExternalSequenceBlockHash(global_hash),
                        tokens_hash: LocalBlockHash(*local_hash),
                        mm_extra_info: None,
                    })
                    .collect(),
            })
        } else {
            KvCacheEventData::Removed(KvCacheRemoveData {
                block_hashes: full_blocks
                    .into_iter()
                    .map(ExternalSequenceBlockHash)
                    .collect(),
            })
        };
        // Event ids are monotonically increasing per manager instance.
        let event_id = self.next_event_id;
        self.next_event_id += 1;
        let event = KvCacheEvent {
            event_id,
            data: event_data,
            dp_rank: self.dp_rank,
        };
        if let Err(e) = self
            .kv_event_publishers
            .publish(event, token_ids.as_deref())
        {
            // Publishing is best-effort; a failing sink must not abort the
            // simulation.
            tracing::warn!("Failed to publish KV event: {e}");
        }
    }

    /// Applies one block-movement event to the cache.
    ///
    /// Returns the number of blocks actually allocated for [`MoveBlock::Use`]
    /// (which may be fewer than requested when eviction cannot free enough
    /// capacity); returns `1` for all other variants.
    ///
    /// # Panics
    /// * `Destroy` — if a block is not currently active.
    /// * `Deref` — if an active block's reference count is already 0.
    /// * `Use` — if a partial block appears as the parent of stored blocks.
    /// * `Promote` — if the partial block is absent or shared (ref_count != 1).
    pub fn process(&mut self, event: &MoveBlock) -> usize {
        match event {
            // Acquire (or re-acquire) a run of blocks for a sequence.
            MoveBlock::Use(hashes, local_hashes, token_ids, parent) => {
                // Sequence hashes of blocks newly inserted by this call; these
                // make up the outgoing store event.
                let mut blocks_stored = Vec::<u64>::new();
                // Mirror of `token_ids` restricted to the newly stored blocks;
                // only populated when the caller supplied token ids.
                let mut stored_token_ids: Option<Vec<Vec<u32>>> =
                    token_ids.as_ref().map(|_| Vec::new());
                // Tracks the most recent block seen, which becomes the parent
                // of the first newly stored block.
                let mut parent_block: Option<&UniqueBlock> = parent.as_ref();
                let mut allocated = 0;
                for (i, hash) in hashes.iter().enumerate() {
                    // Already active: just take another reference.
                    if self.cache.contains_active(hash) {
                        self.cache.increment_ref(hash);
                        parent_block = Some(hash);
                        allocated += 1;
                        continue;
                    }
                    // Cached but inactive: move it back to the active pool.
                    if self.cache.reactivate(hash) {
                        parent_block = Some(hash);
                        allocated += 1;
                        continue;
                    }
                    // Genuinely new block: make room first if the cache is
                    // full. If nothing is evictable, stop allocating here —
                    // the remaining blocks are simply not allocated.
                    if self.cache.is_at_capacity() {
                        let Some(evicted) = self.cache.evict_inactive() else {
                            break;
                        };
                        tracing::trace!(
                            "Evicting block from inactive pool: {evicted:?}, dp_rank={}",
                            self.dp_rank
                        );
                        // Only full blocks are visible externally, so only
                        // they generate remove events.
                        if let UniqueBlock::FullBlock(evicted_full_block) = evicted {
                            self.publish_kv_event(vec![evicted_full_block], &[], None, false, None);
                        }
                    }
                    self.cache.insert_active(hash.clone(), 1);
                    allocated += 1;
                    if let UniqueBlock::FullBlock(stored_full_block) = hash {
                        blocks_stored.push(*stored_full_block);
                        if let Some(ref mut stids) = stored_token_ids {
                            // Safe by construction: `stored_token_ids` is Some
                            // only when `token_ids` is Some.
                            stids.push(token_ids.as_ref().unwrap()[i].clone());
                        }
                    }
                }
                // The parent of stored blocks must itself be a full block (or
                // absent for the first chunk of a sequence).
                let parent_hash = match parent_block {
                    None => None,
                    Some(UniqueBlock::FullBlock(block)) => Some(*block),
                    Some(UniqueBlock::PartialBlock(_)) => panic!("parent block cannot be partial"),
                };
                self.publish_kv_event(
                    blocks_stored,
                    local_hashes,
                    parent_hash,
                    true,
                    stored_token_ids,
                );
                return allocated;
            }
            // Drop blocks entirely (no transition through the inactive pool).
            MoveBlock::Destroy(hashes) => {
                let mut blocks_destroyed = Vec::<u64>::new();
                for hash in hashes.iter() {
                    // Destroying a non-active block is an invariant violation.
                    self.cache.remove_active(hash).unwrap();
                    if let UniqueBlock::FullBlock(destroyed_full_block) = hash {
                        blocks_destroyed.push(*destroyed_full_block);
                    }
                }
                self.publish_kv_event(blocks_destroyed, &[], None, false, None);
            }
            // Release one reference; the last release moves the block to the
            // inactive (evictable) pool. Blocks not currently active are
            // silently skipped.
            MoveBlock::Deref(hashes) => {
                for hash in hashes.iter() {
                    if let Some(ref_count) = self.cache.get_active_ref_count(hash) {
                        if ref_count == 0 {
                            panic!("Negative reference count would be encountered after Deref.");
                        }
                        if ref_count == 1 {
                            self.cache.deactivate(hash);
                        } else {
                            self.cache.decrement_ref(hash);
                        }
                    }
                }
            }
            // A partial (uuid-keyed) block became a full block: swap the key
            // and emit a store event only if the full hash is genuinely new
            // to the cache (neither active nor inactive).
            MoveBlock::Promote(uuid, hash, parent_hash, local_hash, promote_token_ids) => {
                let uuid_block = UniqueBlock::PartialBlock(*uuid);
                let hash_block = UniqueBlock::FullBlock(*hash);
                // Partial blocks are never shared, so exactly one reference
                // must be outstanding.
                assert_eq!(
                    self.cache.remove_active(&uuid_block),
                    Some(1),
                    "uuid_block {uuid_block:?} should exist and be unique with ref_count=1"
                );
                let hash_ref_count = self.cache.get_active_ref_count(&hash_block);
                let is_new = if hash_ref_count.is_some() {
                    // Already active under the full hash — not new.
                    false
                } else {
                    // `remove_inactive` returns true when the block was in
                    // the inactive pool; in that case it is not new either.
                    !self.cache.remove_inactive(&hash_block)
                };
                // Fold the promoted reference into any existing count.
                self.cache
                    .insert_active(hash_block, hash_ref_count.unwrap_or(0) + 1);
                if is_new {
                    self.publish_kv_event(
                        vec![*hash],
                        &[*local_hash],
                        *parent_hash,
                        true,
                        promote_token_ids.as_ref().map(|t| vec![t.clone()]),
                    );
                }
            }
        }
        1
    }

    /// Counts how many of `blocks` are absent from the cache (neither active
    /// nor inactive).
    pub fn probe_new_blocks(&self, blocks: &[UniqueBlock]) -> usize {
        blocks
            .iter()
            .filter(|&block| !self.cache.contains(block))
            .count()
    }

    /// Total number of blocks currently held (active + inactive).
    pub fn current_capacity(&self) -> usize {
        self.cache.current_capacity()
    }

    /// Occupancy as a fraction of `max_capacity` (0.0..=1.0).
    pub fn current_capacity_perc(&self) -> f64 {
        self.cache.current_capacity() as f64 / self.cache.max_capacity() as f64
    }

    /// Number of blocks with at least one outstanding reference.
    pub fn num_active_blocks(&self) -> usize {
        self.cache.num_active()
    }

    /// Active blocks as a fraction of `max_capacity`.
    pub fn get_active_perc(&self) -> f64 {
        self.cache.num_active() as f64 / self.cache.max_capacity() as f64
    }

    /// Number of cached but evictable (unreferenced) blocks.
    pub fn num_inactive_blocks(&self) -> usize {
        self.cache.num_inactive()
    }

    /// Borrowed view of all inactive block keys.
    pub fn get_inactive_blocks(&self) -> Vec<&UniqueBlock> {
        self.cache.inactive_keys().collect()
    }

    /// Borrowed view of all active block keys.
    pub fn get_active_blocks(&self) -> Vec<&UniqueBlock> {
        self.cache.active_keys().collect()
    }

    /// Maximum number of blocks this manager may hold.
    pub fn max_capacity(&self) -> usize {
        self.cache.max_capacity()
    }

    /// Tokens per block.
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Data-parallel rank this manager represents.
    pub fn dp_rank(&self) -> u32 {
        self.dp_rank
    }

    /// Active block → reference count map.
    pub fn active_blocks(&self) -> &FxHashMap<UniqueBlock, usize> {
        self.cache.active_blocks()
    }

    /// Estimates prefill cost for `sequence` against the current cache.
    ///
    /// Overlap is the longest *prefix* of the sequence's blocks already
    /// cached (active or inactive); the scan stops at the first miss, so a
    /// cached block after a gap does not count. Cached tokens are capped at
    /// the sequence's input-token count because the final overlapping block
    /// may be only partially filled by the prompt.
    pub fn get_prefill_cost(&self, sequence: &ActiveSequence) -> PrefillCost {
        let seq_blocks = sequence.unique_blocks();
        let mut overlap_blocks = 0;
        for block in seq_blocks {
            if !self.cache.contains(block) {
                break;
            }
            overlap_blocks += 1;
        }
        let new_blocks = seq_blocks.len() - overlap_blocks;
        let cached_tokens = (overlap_blocks * self.block_size).min(sequence.num_input_tokens());
        let new_tokens = sequence.num_input_tokens() - cached_tokens;
        PrefillCost {
            new_blocks,
            new_tokens,
            cached_tokens,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use crate::common::protocols::KvCacheEventSink;

    /// With every block active (refcount > 0) nothing is evictable, so a
    /// request that exceeds `max_capacity` must allocate zero blocks.
    #[test]
    fn test_failure_on_max_capacity() {
        let mut manager = KvManager::new(10, 16);
        // Helper: each id doubles as both the full-block hash and its local hash.
        fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) -> usize {
            let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
            let hashes: Vec<_> = ids.into_iter().collect();
            manager.process(&MoveBlock::Use(blocks, hashes, None, None))
        }
        let response = use_blocks(&mut manager, (0..10).collect());
        assert_eq!(response, 10, "Expected all 10 blocks allocated");
        assert_eq!(manager.current_capacity(), 10);
        // All 10 blocks are still referenced; eviction fails and the extra
        // block is rejected.
        let response = use_blocks(&mut manager, vec![10]);
        assert_eq!(
            response, 0,
            "Expected 0 blocks allocated when exceeding max capacity"
        );
    }

    /// Walks a block through every lifecycle transition — use, shared use,
    /// destroy, deref-to-inactive, reactivate, and LRU eviction — checking
    /// exact active refcounts and inactive pool contents at each step.
    #[test]
    fn test_block_lifecycle_stringent() {
        let mut manager = KvManager::new(10, 16);
        fn use_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks: Vec<_> = ids.iter().map(|&id| UniqueBlock::FullBlock(id)).collect();
            let hashes: Vec<_> = ids.into_iter().collect();
            manager.process(&MoveBlock::Use(blocks, hashes, None, None));
        }
        fn destroy_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Destroy(blocks));
        }
        fn deref_blocks(manager: &mut KvManager, ids: Vec<u64>) {
            let blocks = ids.into_iter().map(UniqueBlock::FullBlock).collect();
            manager.process(&MoveBlock::Deref(blocks));
        }
        // Asserts the active pool contains exactly `expected_blocks`, with
        // the given (id, ref_count) pairs.
        fn assert_active_blocks(manager: &KvManager, expected_blocks: &[(u64, usize)]) {
            assert_eq!(
                manager.active_blocks().len(),
                expected_blocks.len(),
                "Active blocks count doesn't match expected"
            );
            for &(id, ref_count) in expected_blocks {
                let block = UniqueBlock::FullBlock(id);
                assert!(
                    manager.active_blocks().contains_key(&block),
                    "Block {id} not found in active blocks",
                );
                assert_eq!(
                    manager.active_blocks().get(&block),
                    Some(&ref_count),
                    "Block {id} has wrong reference count",
                );
            }
        }
        // Asserts the inactive pool size and that it contains (at least)
        // `expected_blocks`.
        fn assert_inactive_blocks(
            manager: &KvManager,
            expected_size: usize,
            expected_blocks: &[u64],
        ) {
            let inactive_blocks = manager.get_inactive_blocks();
            let inactive_blocks_count = manager.num_inactive_blocks();
            assert_eq!(
                inactive_blocks_count, expected_size,
                "Inactive blocks count doesn't match expected"
            );
            for &id in expected_blocks {
                let block = UniqueBlock::FullBlock(id);
                assert!(
                    inactive_blocks.iter().any(|&b| *b == block),
                    "Block {id} not found in inactive blocks",
                );
            }
        }
        // Blocks 0 and 1 are shared between the two sequences → refcount 2.
        use_blocks(&mut manager, (0..5).collect());
        use_blocks(&mut manager, vec![0, 1, 5, 6]);
        assert_active_blocks(
            &manager,
            &[(0, 2), (1, 2), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1)],
        );
        // Destroy removes outright; deref moves last-reference blocks to the
        // inactive pool and merely decrements shared ones.
        destroy_blocks(&mut manager, vec![4]);
        deref_blocks(&mut manager, vec![0, 1, 2, 3]);
        assert_inactive_blocks(&manager, 2, &[3, 2]);
        assert_active_blocks(&manager, &[(0, 1), (1, 1), (5, 1), (6, 1)]);
        destroy_blocks(&mut manager, vec![6]);
        deref_blocks(&mut manager, vec![0, 1, 5]);
        assert_inactive_blocks(&manager, 5, &[0, 1, 2, 3, 5]);
        assert_active_blocks(&manager, &[]);
        // Re-using 0/1/2 reactivates them from the inactive pool; 7/8/9 are
        // brand new.
        use_blocks(&mut manager, vec![0, 1, 2, 7, 8, 9]);
        assert_inactive_blocks(&manager, 2, &[3, 5]);
        assert_active_blocks(&manager, &[(0, 1), (1, 1), (2, 1), (7, 1), (8, 1), (9, 1)]);
        // Only block 4 (destroyed earlier) is absent from the cache.
        let blocks_to_check: Vec<UniqueBlock> = vec![0, 1, 2, 3, 4]
            .into_iter()
            .map(UniqueBlock::FullBlock)
            .collect();
        assert_eq!(manager.probe_new_blocks(&blocks_to_check), 1);
        // New blocks force eviction from the inactive pool once at capacity.
        use_blocks(&mut manager, vec![10, 11, 12]);
        assert_inactive_blocks(&manager, 1, &[5]);
        use_blocks(&mut manager, vec![13]);
    }

    /// Chunked prefill: the second chunk's store event must carry the last
    /// block of the first chunk as its parent hash.
    #[test]
    fn test_chunked_prefill_parent_hash() {
        use std::sync::Mutex;
        use crate::common::sequence::ActiveSequence;
        // Sink that records every published event for later inspection.
        #[derive(Default)]
        struct CapturingSink {
            events: Mutex<Vec<KvCacheEvent>>,
        }
        impl KvCacheEventSink for CapturingSink {
            fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
                self.events.lock().unwrap().push(event);
                Ok(())
            }
        }
        let block_size = 64;
        // 512 tokens = 8 full blocks; prefill in two 256-token (4-block) chunks.
        let tokens: Vec<u32> = (0..512).collect();
        let mut seq = ActiveSequence::new(tokens, 100, Some(block_size), true, false);
        let sink = Arc::new(CapturingSink::default());
        let mut manager = KvManager::new_with_event_sink(
            256,
            block_size,
            KvEventPublishers::new(Some(sink.clone() as _), None),
            0,
        );
        let signal = seq.prepare_allocation(256).unwrap();
        manager.process(&signal);
        seq.commit_allocation(256);
        let signal = seq.prepare_allocation(512).unwrap();
        manager.process(&signal);
        seq.commit_allocation(512);
        let events = sink.events.lock().unwrap();
        assert_eq!(events.len(), 2, "expected two store events");
        let KvCacheEventData::Stored(ref store1) = events[0].data else {
            panic!("expected store event");
        };
        assert!(
            store1.parent_hash.is_none(),
            "first chunk should have no parent"
        );
        let KvCacheEventData::Stored(ref store2) = events[1].data else {
            panic!("expected store event");
        };
        // Block index 3 is the last block of the first chunk.
        let expected_parent = seq.unique_blocks()[3].clone();
        let UniqueBlock::FullBlock(expected_hash) = expected_parent else {
            panic!("expected full block");
        };
        assert_eq!(
            store2.parent_hash,
            Some(ExternalSequenceBlockHash(expected_hash)),
            "second chunk's parent should be block 3's seq_hash"
        );
    }

    /// After a preemption and a partial (prompt-only) recompute, a second
    /// preemption must free only the blocks that were actually re-allocated.
    #[test]
    fn test_repreempt_after_partial_recompute_only_frees_reallocated_blocks() {
        // 6-token prompt, block size 4 → 2 blocks initially.
        let mut seq = ActiveSequence::new((0..6).collect(), 16, Some(4), true, false);
        let mut manager = KvManager::new(16, 4);
        let signal = seq.take_creation_signal().unwrap();
        assert_eq!(manager.process(&signal), 2);
        // Generate 3 tokens, committing after each step that isn't the last.
        for _ in 0..3 {
            let signals = seq.generate();
            for signal in &signals {
                manager.process(signal);
            }
            if seq.generated_tokens() < seq.max_output_tokens() {
                seq.commit_allocation(seq.len());
            }
        }
        assert_eq!(manager.num_active_blocks(), 3);
        // First preemption releases everything.
        let first_reset = seq.reset_with_signal();
        for signal in &first_reset {
            manager.process(signal);
        }
        assert_eq!(manager.num_active_blocks(), 0);
        // Recompute only the prompt (2 blocks), then preempt again: exactly
        // those 2 blocks must be released.
        let prompt_only = seq.prepare_allocation(seq.num_input_tokens()).unwrap();
        assert_eq!(manager.process(&prompt_only), 2);
        seq.commit_allocation(seq.num_input_tokens());
        assert_eq!(manager.num_active_blocks(), 2);
        let second_reset = seq.reset_with_signal();
        for signal in &second_reset {
            manager.process(signal);
        }
        assert_eq!(manager.num_active_blocks(), 0);
    }
}