use std::collections::{HashMap, HashSet};
use dynamo_kv_router::protocols::XXH3_SEED;
// Hash of a single block's token ids alone (position-independent).
type LocalBlockHash = u64;
// Hash chaining a block's local hash with its parent chain (position-dependent).
type SequenceHash = u64;
/// Hashes a block's token ids with seeded XXH3-64.
///
/// Tokens are serialized as consecutive little-endian `u32` byte groups, so
/// the result depends only on the token values and their order within the
/// block, not on the block's position in a sequence.
fn compute_local_block_hash(token_ids: &[u32]) -> LocalBlockHash {
    let mut buffer = Vec::with_capacity(token_ids.len() * 4);
    for &token in token_ids {
        buffer.extend_from_slice(&token.to_le_bytes());
    }
    xxhash_rust::xxh3::xxh3_64_with_seed(&buffer, XXH3_SEED)
}
/// Combines a block's local hash with its parent's sequence hash.
///
/// A chain root (no parent) is identified by its own local hash. Otherwise
/// the two 64-bit hashes are laid out as 16 little-endian bytes
/// (parent first, then current) and hashed with seeded XXH3-64 — the same
/// byte layout as serializing `[parent, current]` element by element.
fn compute_sequence_hash(
    parent_sequence_hash: Option<SequenceHash>,
    current_block_hash: LocalBlockHash,
) -> SequenceHash {
    match parent_sequence_hash {
        None => current_block_hash,
        Some(parent_hash) => {
            let mut buffer = [0u8; 16];
            buffer[..8].copy_from_slice(&parent_hash.to_le_bytes());
            buffer[8..].copy_from_slice(&current_block_hash.to_le_bytes());
            xxhash_rust::xxh3::xxh3_64_with_seed(&buffer, XXH3_SEED)
        }
    }
}
/// Engine/component that reported a KV-cache event.
///
/// NOTE(review): serialized via the serde derives, so renaming variants
/// would presumably change the wire format — confirm before refactoring.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum EventSource {
/// vLLM engine.
Vllm,
/// TensorRT-LLM engine.
Trtllm,
/// KV block manager.
Kvbm,
}
impl std::str::FromStr for EventSource {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"vllm" | "VLLM" | "GPU" => Ok(EventSource::Vllm),
"trtllm" | "TRTLLM" | "TensorRT-LLM" => Ok(EventSource::Trtllm),
"kvbm" | "KVBM" => Ok(EventSource::Kvbm),
_ => Err(format!("Unknown event source: {}", s)),
}
}
}
impl EventSource {
pub fn to_str(&self) -> &'static str {
match self {
EventSource::Vllm => "vllm",
EventSource::Trtllm => "trtllm",
EventSource::Kvbm => "kvbm",
}
}
}
/// Physical storage tier a cached block may live in.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StorageTier {
    /// GPU device memory.
    Device,
    /// Pinned host (CPU) memory.
    HostPinned,
    /// Disk-backed storage.
    Disk,
}

impl StorageTier {
    /// Maps a vLLM medium label to a tier, or `None` if unrecognized.
    pub fn from_vllm_medium(s: &str) -> Option<Self> {
        let tier = match s {
            "GPU" => Self::Device,
            "CPU_TIER1" => Self::HostPinned,
            "CPU_TIER2" => Self::Disk,
            _ => return None,
        };
        Some(tier)
    }

    /// Inverse of [`StorageTier::from_vllm_medium`].
    pub fn to_vllm_medium(&self) -> &'static str {
        match *self {
            Self::Device => "GPU",
            Self::HostPinned => "CPU_TIER1",
            Self::Disk => "CPU_TIER2",
        }
    }

    /// Canonical lowercase label for logging / display.
    pub fn to_str(&self) -> &'static str {
        match *self {
            Self::Device => "device",
            Self::HostPinned => "host_pinned",
            Self::Disk => "disk",
        }
    }
}
/// Legacy name kept for backward compatibility; prefer `StorageTier`.
#[deprecated(note = "Use StorageTier instead")]
pub type StorageMedium = StorageTier;
/// Per-block tracking state: which sources currently hold the block, plus the
/// external hash it was first registered under.
#[derive(Debug, Clone)]
pub struct BlockMetadata {
/// Sources that currently report this block as stored.
pub sources: HashSet<EventSource>,
/// External hash from the first STORE; reused when publishing REMOVE so
/// downstream consumers see a consistent identifier.
pub first_block_hash: String,
}
impl BlockMetadata {
    /// Builds metadata for a block first observed in `source`; `block_hash`
    /// is the external hash under which the block was first registered.
    pub fn new(source: EventSource, block_hash: String) -> Self {
        let sources: HashSet<EventSource> = std::iter::once(source).collect();
        Self {
            sources,
            first_block_hash: block_hash,
        }
    }

    /// True while at least one source still holds the block.
    pub fn exists_in_any_source(&self) -> bool {
        !self.sources.is_empty()
    }

    /// Records `source`; returns true if it was not already present.
    pub fn add_source(&mut self, source: EventSource) -> bool {
        self.sources.insert(source)
    }

    /// Forgets `source`; returns true if it was present.
    pub fn remove_source(&mut self, source: EventSource) -> bool {
        self.sources.remove(&source)
    }
}
/// Deduplicated cache event queued for downstream publishing.
#[derive(Debug, Clone)]
pub enum ConsolidatedEvent {
/// A block's content entered the cache for the first time (first source to store it).
Store {
block_hash: String,
/// Parent's external hash, resolved to the parent's first-seen hash when
/// the parent is tracked; `None` otherwise.
parent_hash: Option<String>,
token_ids: Vec<u32>,
block_size: usize,
lora_name: Option<String>,
/// Label of the source that triggered the event (see `EventSource::to_str`).
source: String,
},
/// The last source holding a block dropped it.
Remove {
block_hash: String,
source: String, },
/// All tracked state was cleared.
ClearAll,
}
/// Consolidates KV-cache STORE/REMOVE events from multiple sources.
///
/// Blocks are deduplicated by a content-derived sequence hash, so an event
/// is queued for publishing only on the first store / last remove of a given
/// piece of content, regardless of which sources report it.
#[derive(Debug)]
pub struct CacheStatusTracker {
/// Tracked blocks keyed by content-chained sequence hash.
blocks: HashMap<SequenceHash, BlockMetadata>,
/// External (per-source) block hash -> sequence hash.
hash_mapping: HashMap<String, SequenceHash>,
/// Events pending publication; drained via `drain_events`.
event_queue: Vec<ConsolidatedEvent>,
}
impl Default for CacheStatusTracker {
/// Equivalent to [`CacheStatusTracker::new`].
fn default() -> Self {
Self::new()
}
}
impl CacheStatusTracker {
/// Creates an empty tracker: no blocks, no hash mappings, no queued events.
pub fn new() -> Self {
    Self {
        blocks: HashMap::default(),
        hash_mapping: HashMap::default(),
        event_queue: Vec::default(),
    }
}
/// Registers a STORE of `block_hash` (a source-specific external hash) from
/// `source`.
///
/// The block's identity is a content-derived sequence hash: its token ids
/// chained with the parent's sequence hash. Identical content stored by
/// different sources therefore collapses onto one tracked block.
///
/// Returns `true` only when this is the first source to store the content,
/// i.e. when a `Store` event was queued for publishing; duplicate stores
/// only record the extra source/alias and return `false`.
///
/// `tier` and `data_parallel_rank` are currently used for logging only.
#[allow(clippy::too_many_arguments)]
pub fn handle_store(
&mut self,
block_hash: String,
source: EventSource,
token_ids: Vec<u32>,
parent_hash: Option<String>,
block_size: usize,
lora_name: Option<String>,
tier: Option<StorageTier>,
data_parallel_rank: Option<i32>,
) -> bool {
let local_block_hash = compute_local_block_hash(&token_ids);
// Resolve the parent's sequence hash via its external hash; yields None
// when the parent was never stored (or its mapping was already removed),
// in which case this block is treated as a chain root.
let parent_sequence_hash = parent_hash
.as_ref()
.and_then(|ph| self.hash_mapping.get(ph).copied());
let sequence_hash = compute_sequence_hash(parent_sequence_hash, local_block_hash);
tracing::debug!(
"Computing sequence_hash for block: local_block_hash={}, parent_seq_hash={:?}, sequence_hash={}",
local_block_hash,
parent_sequence_hash,
sequence_hash
);
if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
// Content already tracked: record the (possibly new) source and map
// this source's external hash onto the existing block. No event.
let is_new_source = metadata.add_source(source);
self.hash_mapping.insert(block_hash.clone(), sequence_hash);
if is_new_source {
tracing::debug!(
"DEDUP: Block {} (seq_hash={}) added to source {:?} (already exists in {} source(s), {} tokens, external_hash={})\n Token IDs: {:?}",
&metadata.first_block_hash[..16.min(metadata.first_block_hash.len())],
sequence_hash,
source,
metadata.sources.len(),
token_ids.len(),
&block_hash[..16.min(block_hash.len())],
&token_ids
);
} else {
tracing::debug!(
"Block {} (seq_hash={}) already in source {:?}, external_hash={}\n Token IDs: {:?}",
&metadata.first_block_hash[..16.min(metadata.first_block_hash.len())],
sequence_hash,
source,
&block_hash[..16.min(block_hash.len())],
&token_ids
);
}
false
} else {
// First time this content is seen: track it and queue a STORE event.
let metadata = BlockMetadata::new(source, block_hash.clone());
tracing::debug!(
"New block {} (seq_hash={}) stored in source {:?} (tier={:?}): {} tokens, block_size={}, parent={}, lora={:?}, dp_rank={:?}\n Token IDs: {:?}",
&block_hash[..16.min(block_hash.len())],
sequence_hash,
source,
tier,
token_ids.len(),
block_size,
parent_hash
.as_ref()
.map(|p| &p[..16.min(p.len())])
.unwrap_or("none"),
lora_name,
data_parallel_rank,
&token_ids
);
self.blocks.insert(sequence_hash, metadata);
self.hash_mapping.insert(block_hash.clone(), sequence_hash);
// Publish the parent's first-seen external hash (not the caller's
// alias) so downstream consumers see one consistent chain even when
// sources use different external hashes for the same content.
let resolved_parent_hash = parent_hash.and_then(|ph| {
self.hash_mapping.get(&ph).and_then(|&parent_seq_hash| {
self.blocks
.get(&parent_seq_hash)
.map(|parent_metadata| parent_metadata.first_block_hash.clone())
})
});
self.event_queue.push(ConsolidatedEvent::Store {
block_hash: block_hash.clone(),
parent_hash: resolved_parent_hash,
token_ids,
block_size,
lora_name,
source: source.to_str().to_string(),
});
tracing::debug!(
"Block {} (seq_hash={}) stored in first source {:?}, will publish STORE event (total tracked: {}, hash_mapping: {})",
block_hash,
sequence_hash,
source,
self.blocks.len(),
self.hash_mapping.len()
);
true
}
}
/// Registers a REMOVE of `block_hash` (a source-specific external hash) from
/// `source`.
///
/// Returns `true` only when the last source holding the block removed it,
/// i.e. when a `Remove` event was queued for publishing. Unknown hashes and
/// sources that never held the block are logged and return `false`.
pub fn handle_remove(&mut self, block_hash: &str, source: EventSource) -> bool {
let sequence_hash = match self.hash_mapping.get(block_hash) {
Some(&hash) => hash,
None => {
// External hash was never stored (or its alias was already removed).
tracing::warn!(
"Attempted to remove unknown block {} from source {:?} (not in hash_mapping)",
block_hash,
source
);
return false;
}
};
if let Some(metadata) = self.blocks.get_mut(&sequence_hash) {
let was_removed = metadata.remove_source(source);
if !was_removed {
tracing::warn!(
"Attempted to remove source {:?} from block {} but it wasn't present",
source,
block_hash
);
return false;
}
// Drop this external alias even if other sources still hold the block;
// the other sources' aliases remain in hash_mapping.
self.hash_mapping.remove(block_hash);
tracing::debug!(
"Removed hash_mapping entry for {} (hash_mapping size: {})",
block_hash,
self.hash_mapping.len()
);
if !metadata.exists_in_any_source() {
// Last source gone: untrack the block entirely and publish REMOVE
// under the first-seen external hash.
let first_block_hash = metadata.first_block_hash.clone();
self.blocks.remove(&sequence_hash);
// Defensive sweep: purge any leftover aliases still pointing at this
// sequence hash so hash_mapping cannot leak stale entries.
let stray_count_before = self.hash_mapping.len();
self.hash_mapping
.retain(|_ext_hash, &mut seq_hash| seq_hash != sequence_hash);
let stray_count = stray_count_before - self.hash_mapping.len();
if stray_count > 0 {
tracing::warn!(
"Found {} stray hash_mapping entries for seq_hash={} after all sources removed - cleaned up (hash_mapping size now: {})",
stray_count,
sequence_hash,
self.hash_mapping.len()
);
}
self.event_queue.push(ConsolidatedEvent::Remove {
block_hash: first_block_hash.clone(),
source: source.to_str().to_string(),
});
tracing::debug!(
"Block {} (seq_hash={}) removed from last source {:?}, will publish REMOVE event (total tracked: {}, hash_mapping: {})",
first_block_hash,
sequence_hash,
source,
self.blocks.len(),
self.hash_mapping.len()
);
true
} else {
tracing::debug!(
"Block {} (seq_hash={}) removed from source {:?}, still in {} source(s): {:?} (hash_mapping: {})",
&metadata.first_block_hash[..16.min(metadata.first_block_hash.len())],
sequence_hash,
source,
metadata.sources.len(),
metadata.sources,
self.hash_mapping.len()
);
false
}
} else {
// hash_mapping pointed at a sequence hash that is no longer tracked;
// NOTE(review): indicates an internal inconsistency — mapping and
// blocks are expected to stay in sync.
tracing::warn!(
"Attempted to remove block {} from source {:?} but block not tracked",
&block_hash[..16.min(block_hash.len())],
source
);
false
}
}
/// Drops all tracked blocks and hash mappings, then queues a single
/// `ClearAll` event for downstream publishers.
pub fn handle_clear_all(&mut self) {
    tracing::debug!("Clearing all {} blocks from tracker", self.blocks.len());
    self.blocks.clear();
    self.hash_mapping.clear();
    self.event_queue.push(ConsolidatedEvent::ClearAll);
}
/// Takes ownership of every queued event, leaving the queue empty.
pub fn drain_events(&mut self) -> Vec<ConsolidatedEvent> {
    let drained = std::mem::replace(&mut self.event_queue, Vec::new());
    match drained.len() {
        0 => {}
        count => tracing::debug!(
            "Draining {} pending kv event(s) for publishing",
            count
        ),
    }
    drained
}
/// Number of unique blocks currently tracked (keyed by sequence hash).
pub fn num_blocks(&self) -> usize {
self.blocks.len()
}
/// Looks up the set of sources holding the block registered under
/// `external_block_hash`; `None` if the hash is unknown.
pub fn get_block_sources(&self, external_block_hash: &str) -> Option<&HashSet<EventSource>> {
    self.hash_mapping
        .get(external_block_hash)
        .and_then(|seq_hash| self.blocks.get(seq_hash))
        .map(|metadata| &metadata.sources)
}
/// Deprecated alias for [`CacheStatusTracker::get_block_sources`], kept for
/// callers written against the old tier-based naming.
#[deprecated(note = "Use get_block_sources instead")]
pub fn get_block_tiers(&self, block_hash: &str) -> Option<&HashSet<EventSource>> {
self.get_block_sources(block_hash)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// First STORE of new content is published and tracked.
    #[test]
    fn test_first_store_publishes() {
        let mut tracker = CacheStatusTracker::new();
        let should_publish = tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 1);
        assert_eq!(tracker.drain_events().len(), 1);
    }

    /// Re-storing identical content from the same source publishes nothing.
    #[test]
    fn test_duplicate_store_no_publish() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(!should_publish);
        assert_eq!(tracker.drain_events().len(), 0);
    }

    /// Same content from a second source dedups onto one block with two sources.
    #[test]
    fn test_multi_source_store() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "vllm_hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_store(
            "kvbm_hash1".to_string(),
            EventSource::Kvbm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        assert!(!should_publish);
        // Intentionally exercises the deprecated shim to keep it covered.
        #[allow(deprecated)]
        let sources = tracker.get_block_tiers("vllm_hash1").unwrap();
        assert_eq!(sources.len(), 2);
    }

    /// Removing a block held by only one source publishes a REMOVE event.
    #[test]
    fn test_remove_from_single_source_publishes() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_remove("hash1", EventSource::Vllm);
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 0);
        let events = tracker.drain_events();
        assert_eq!(events.len(), 1);
        // BUG FIX: `matches!` yields a bool; the bare expression discarded it,
        // so the variant check never asserted. Wrap it in `assert!`.
        assert!(matches!(events[0], ConsolidatedEvent::Remove { .. }));
    }

    /// REMOVE is published only when the last holding source drops the block.
    #[test]
    fn test_remove_from_multi_source_no_publish() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "vllm_hash1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.handle_store(
            "kvbm_hash1".to_string(),
            EventSource::Kvbm,
            vec![1, 2, 3],
            None,
            3,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_remove("vllm_hash1", EventSource::Vllm);
        assert!(!should_publish);
        assert_eq!(tracker.num_blocks(), 1);
        assert_eq!(tracker.drain_events().len(), 0);
        let should_publish = tracker.handle_remove("kvbm_hash1", EventSource::Kvbm);
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 0);
    }

    /// A chain-root block (no parent) stores and publishes normally.
    #[test]
    fn test_sequence_hash_first_block() {
        let mut tracker = CacheStatusTracker::new();
        let should_publish = tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 1);
        let events = tracker.drain_events();
        assert_eq!(events.len(), 1);
    }

    /// A child block with a tracked parent is a distinct tracked block.
    #[test]
    fn test_sequence_hash_with_parent() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_store(
            "block2".to_string(),
            EventSource::Vllm,
            vec![5, 6, 7, 8],
            Some("block1".to_string()),
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 2);
    }

    /// Identical tokens at different chain positions hash to different blocks.
    #[test]
    fn test_same_tokens_different_position_different_blocks() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_store(
            "block2".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            Some("block1".to_string()),
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        assert!(should_publish);
        assert_eq!(tracker.num_blocks(), 2);
    }

    /// ClearAll wipes all state; subsequent removes are unknown-block no-ops.
    #[test]
    fn test_clear_all() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "block1".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.handle_store(
            "block2".to_string(),
            EventSource::Kvbm,
            vec![5, 6, 7, 8],
            None,
            4,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        assert_eq!(tracker.num_blocks(), 2);
        tracker.handle_clear_all();
        assert_eq!(tracker.num_blocks(), 0);
        let should_publish = tracker.handle_remove("block1", EventSource::Vllm);
        assert!(!should_publish);
    }

    /// A child stored by a second source dedups even when the parent alias
    /// belongs to a different source.
    #[test]
    fn test_deduplication_across_sources_with_parent() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "vllm_parent".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        tracker.handle_store(
            "vllm_child".to_string(),
            EventSource::Vllm,
            vec![5, 6, 7, 8],
            Some("vllm_parent".to_string()),
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        tracker.drain_events();
        let should_publish = tracker.handle_store(
            "kvbm_child".to_string(),
            EventSource::Kvbm,
            vec![5, 6, 7, 8],
            Some("vllm_parent".to_string()),
            4,
            None,
            Some(StorageTier::HostPinned),
            None,
        );
        assert!(!should_publish);
        assert_eq!(tracker.num_blocks(), 2);
    }

    /// Removing a never-stored hash is a logged no-op.
    #[test]
    fn test_remove_non_existent_block() {
        let mut tracker = CacheStatusTracker::new();
        let should_publish = tracker.handle_remove("non_existent", EventSource::Vllm);
        assert!(!should_publish);
        assert_eq!(tracker.num_blocks(), 0);
    }

    /// Local block hashing is deterministic and content-sensitive.
    #[test]
    fn test_compute_local_block_hash_deterministic() {
        let tokens1 = vec![1, 2, 3, 4];
        let tokens2 = vec![1, 2, 3, 4];
        let tokens3 = vec![1, 2, 3, 5];
        let hash1 = compute_local_block_hash(&tokens1);
        let hash2 = compute_local_block_hash(&tokens2);
        let hash3 = compute_local_block_hash(&tokens3);
        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }

    /// LoRA adapter name survives the trip into the published Store event.
    #[test]
    fn test_lora_name_round_trip_through_tracker() {
        let mut tracker = CacheStatusTracker::new();
        let should_publish = tracker.handle_store(
            "hash_lora".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            Some("my-adapter".to_string()),
            Some(StorageTier::Device),
            None,
        );
        assert!(should_publish);
        let events = tracker.drain_events();
        assert_eq!(events.len(), 1);
        match &events[0] {
            ConsolidatedEvent::Store {
                lora_name,
                token_ids,
                ..
            } => {
                assert_eq!(lora_name.as_deref(), Some("my-adapter"));
                assert_eq!(token_ids, &[1, 2, 3, 4]);
            }
            other => panic!("expected Store event, got: {:?}", other),
        }
    }

    /// Base-model stores carry no LoRA name.
    #[test]
    fn test_lora_name_none_for_base_model() {
        let mut tracker = CacheStatusTracker::new();
        tracker.handle_store(
            "hash_base".to_string(),
            EventSource::Vllm,
            vec![1, 2, 3, 4],
            None,
            4,
            None,
            Some(StorageTier::Device),
            None,
        );
        let events = tracker.drain_events();
        assert_eq!(events.len(), 1);
        match &events[0] {
            ConsolidatedEvent::Store { lora_name, .. } => {
                assert!(lora_name.is_none());
            }
            other => panic!("expected Store event, got: {:?}", other),
        }
    }

    /// Sequence hashing: root equals local hash; chaining is deterministic
    /// and parent-sensitive.
    #[test]
    fn test_compute_sequence_hash_deterministic() {
        let block_hash1 = compute_local_block_hash(&[1, 2, 3, 4]);
        let block_hash2 = compute_local_block_hash(&[5, 6, 7, 8]);
        let seq_hash1 = compute_sequence_hash(None, block_hash1);
        assert_eq!(seq_hash1, block_hash1);
        let seq_hash2_v1 = compute_sequence_hash(Some(seq_hash1), block_hash2);
        let seq_hash2_v2 = compute_sequence_hash(Some(seq_hash1), block_hash2);
        assert_eq!(seq_hash2_v1, seq_hash2_v2);
        let different_parent = compute_local_block_hash(&[9, 10, 11, 12]);
        let seq_hash2_different = compute_sequence_hash(Some(different_parent), block_hash2);
        assert_ne!(seq_hash2_v1, seq_hash2_different);
    }
}