use std::collections::HashMap;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use crate::cache::radix_cache::{NodeId, RadixCache};
use crate::common::kv_cache_trace;
use crate::common::protocols::KvEventPublishers;
use dynamo_kv_router::protocols::{
ExternalSequenceBlockHash, KvCacheEvent, KvCacheEventData, KvCacheRemoveData, KvCacheStoreData,
KvCacheStoredBlockData,
};
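/// Outcome of a KV allocation: how many tokens were served from the radix
/// cache, the KV slot indices covering the whole request, and the deepest
/// matched node, which stays lock-protected until the request is cached or
/// freed.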
pub struct AllocResult {
pub prefix_len: usize,
pub kv_indices: Vec<usize>,
pub last_node: NodeId,
}
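/// SGLang-style KV cache manager. Wraps a [`RadixCache`] for prefix reuse and
/// mirrors allocations and evictions to the configured KV event publishers so
/// a router can track which blocks live on this worker (`dp_rank`).
///
/// A minimal sketch of the intended prefill/decode flow (illustrative only,
/// not compiled as a doctest):
///
/// ```ignore
/// let mut mgr = SglangKvManager::new(total_tokens, page_size, publishers, dp_rank);
/// // Prefill: reuse whatever prefix the radix cache already holds.
/// let alloc = mgr.allocate_for_request(&token_ids).expect("out of KV slots");
/// // Decode: one new slot per generated token, chained off the previous slot.
/// let idx = mgr.allocate_decode_token(alloc.kv_indices.last().copied());
/// // Done: insert the sequence into the radix cache and release the lock.
/// mgr.cache_finished_req(&token_ids, &alloc.kv_indices, alloc.last_node);
/// ```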
pub struct SglangKvManager {
cache: RadixCache,
kv_event_publishers: KvEventPublishers,
dp_rank: u32,
next_event_id: u64,
    /// Block hash most recently published for each KV slot index.
    idx_to_block_hash: HashMap<usize, ExternalSequenceBlockHash>,
    /// Number of live slots currently mapped to each published block hash; a
    /// Removed event is emitted only when the count drops to zero.
    block_hash_refcounts: HashMap<ExternalSequenceBlockHash, usize>,
}
impl SglangKvManager {
pub fn new(
total_tokens: usize,
page_size: usize,
kv_event_publishers: KvEventPublishers,
dp_rank: u32,
) -> Self {
Self {
cache: RadixCache::new(total_tokens, page_size),
kv_event_publishers,
dp_rank,
next_event_id: 0,
idx_to_block_hash: HashMap::new(),
block_hash_refcounts: HashMap::new(),
}
}
pub fn cache(&self) -> &RadixCache {
&self.cache
}
pub fn cache_mut(&mut self) -> &mut RadixCache {
&mut self.cache
}
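    /// Allocates KV slots for a new request: matches the longest cached
    /// prefix, allocates fresh slots for the remaining tokens, lock-protects
    /// the matched node, and publishes a Stored event for the new blocks.
    /// Returns `None` if the token pool cannot satisfy the allocation.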
pub fn allocate_for_request(&mut self, token_ids: &[u64]) -> Option<AllocResult> {
let (prefix_len, last_node) = self.cache.match_prefix(token_ids);
let new_tokens = token_ids.len() - prefix_len;
let prefix_indices = self.collect_path_indices(last_node);
let new_indices = self.cache.token_pool.allocate(new_tokens)?;
let mut kv_indices = prefix_indices;
kv_indices.extend_from_slice(&new_indices);
        // Protect the matched prefix from eviction while this request is live.
        self.cache.inc_lock_ref(last_node);
        // Parent hash is the block hash of the last reused prefix slot;
        // `wrapping_sub` makes the lookup return None when `prefix_len` is 0.
        let parent_hash = kv_indices
            .get(prefix_len.wrapping_sub(1))
            .and_then(|&idx| self.idx_to_block_hash.get(&idx).copied());
self.publish_stored_event(&token_ids[prefix_len..], &new_indices, parent_hash);
self.log_trace("allocation", new_tokens);
Some(AllocResult {
prefix_len,
kv_indices,
last_node,
})
}
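    /// Allocates KV slots when the prefix match was already done by the caller
    /// (e.g. a later chunk of a chunked prefill): like
    /// [`Self::allocate_for_request`], but reuses the supplied `prefix_indices`
    /// and `last_node` instead of re-matching.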
pub fn allocate_after_prefix(
&mut self,
token_ids: &[u64],
prefix_len: usize,
prefix_indices: &[usize],
last_node: NodeId,
) -> Option<AllocResult> {
let new_tokens = token_ids.len().saturating_sub(prefix_len);
let new_indices = self.cache.token_pool.allocate(new_tokens)?;
let mut kv_indices = prefix_indices[..prefix_len].to_vec();
kv_indices.extend_from_slice(&new_indices);
        self.cache.inc_lock_ref(last_node);
        // Same parent-hash chaining as in `allocate_for_request`: the last
        // reused prefix slot, or None when `prefix_len` is 0.
        let parent_hash = kv_indices
            .get(prefix_len.wrapping_sub(1))
            .and_then(|&idx| self.idx_to_block_hash.get(&idx).copied());
self.publish_stored_event(&token_ids[prefix_len..], &new_indices, parent_hash);
self.log_trace("allocation", new_tokens);
Some(AllocResult {
prefix_len,
kv_indices,
last_node,
})
}
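    /// Inserts a finished request's tokens into the radix cache and releases
    /// the lock taken at allocation time.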
pub fn cache_finished_req(
&mut self,
token_ids: &[u64],
kv_indices: &[usize],
last_node: NodeId,
) {
self.cache.insert(token_ids, kv_indices);
self.cache.dec_lock_ref(last_node);
}
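    /// Caches the tokens computed so far for a still-running request (chunked
    /// prefill) and moves the protection lock onto the node that now covers
    /// them. Returns the new locked node for the caller to track.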
pub fn cache_unfinished_req(
&mut self,
token_ids: &[u64],
kv_indices: &[usize],
last_node: NodeId,
) -> NodeId {
self.cache.insert(token_ids, kv_indices);
        // Re-match to find the node that now covers the inserted tokens, then
        // move the protection lock from the old node onto it.
        let (_, new_last_node) = self.cache.match_prefix(token_ids);
        self.cache.dec_lock_ref(last_node);
        self.cache.inc_lock_ref(new_last_node);
new_last_node
}
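    /// Allocates a single KV slot for one decode step. `last_idx` is the slot
    /// of the previous token and is used to chain the published block hash.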
pub fn allocate_decode_token(&mut self, last_idx: Option<usize>) -> Option<usize> {
let indices = self.cache.token_pool.allocate(1)?;
let idx = indices[0];
let parent_hash = last_idx.and_then(|i| self.idx_to_block_hash.get(&i).copied());
self.publish_stored_event(&[], &[idx], parent_hash);
self.log_trace("allocation", 1);
Some(idx)
}
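    /// Releases the radix-cache lock for a request without caching its tokens.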
pub fn free_request(&mut self, last_node: NodeId) {
self.cache.dec_lock_ref(last_node);
}
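    /// Returns KV slots to the token pool and publishes a Removed event for
    /// any block hashes no longer referenced by a live slot.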
pub fn free_indices(&mut self, indices: &[usize]) {
if indices.is_empty() {
return;
}
self.cache.token_pool.free(indices);
self.publish_removed_event(indices);
self.log_trace("free", indices.len());
}
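    /// Collects the KV indices stored along the path from the root down to
    /// `last_node` (empty when `last_node` is the root).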
fn collect_path_indices(&self, last_node: NodeId) -> Vec<usize> {
if last_node == self.cache.root() {
return Vec::new();
}
        let mut path = Vec::new();
        let mut current = last_node;
        // Walk up to (but not including) the root.
        while let Some(parent) = self.cache.node(current).parent {
            path.push(current);
            current = parent;
        }
path.reverse();
let mut indices = Vec::new();
for node_id in path {
indices.extend_from_slice(&self.cache.node(node_id).value);
}
indices
}
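    /// Asks the radix cache to evict `num_tokens` tokens and publishes a
    /// Removed event for the slots that were actually freed.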
pub fn evict(&mut self, num_tokens: usize) {
let (evicted, evicted_indices) = self.cache.evict(num_tokens);
if !evicted_indices.is_empty() {
self.publish_removed_event(&evicted_indices);
}
self.log_trace("eviction", evicted);
}
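    /// Emits a structured trace entry describing the cache state after `event`
    /// touched `num_tokens` tokens.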
fn log_trace(&self, event: &str, num_tokens: usize) {
kv_cache_trace::log_sglang_trace(&kv_cache_trace::SglangCacheState {
event,
dp_rank: self.dp_rank,
num_tokens,
page_size: self.cache.page_size(),
available_tokens: self.cache.available_tokens(),
evictable_tokens: self.cache.evictable_size,
protected_tokens: self.cache.protected_size,
total_tokens: self.cache.total_tokens(),
});
}
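    /// Computes chained block hashes for newly allocated slots, updates the
    /// per-slot hash map and refcounts, and publishes a Stored event. The
    /// leading run of blocks already tracked elsewhere is skipped so each
    /// logical block is announced once.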
fn publish_stored_event(
&mut self,
token_ids: &[u64],
indices: &[usize],
parent_hash: Option<ExternalSequenceBlockHash>,
) {
if indices.is_empty() {
return;
}
        let mut computed_blocks = Vec::with_capacity(indices.len());
        // Chain block hashes from the parent block, or from 0 at the start of
        // a sequence.
        let mut running_hash = parent_hash.map_or(0u64, |h| h.0);
        for (i, &idx) in indices.iter().enumerate() {
            // Decode steps pass an empty `token_ids`; fall back to the slot
            // index so each block still gets a distinct hash.
            let token = token_ids.get(i).copied().unwrap_or(idx as u64);
            let token_bytes = token.to_le_bytes();
            let tokens_hash = dynamo_kv_router::protocols::compute_block_hash(&token_bytes);
            // Fold the per-token hash into the running chain hash.
            let mut hasher = DefaultHasher::new();
            running_hash.hash(&mut hasher);
            tokens_hash.0.hash(&mut hasher);
            running_hash = hasher.finish();
let block_hash = ExternalSequenceBlockHash(running_hash);
self.idx_to_block_hash.insert(idx, block_hash);
*self.block_hash_refcounts.entry(block_hash).or_default() += 1;
computed_blocks.push(KvCacheStoredBlockData {
block_hash,
tokens_hash,
mm_extra_info: None,
});
}
        if self.kv_event_publishers.is_empty() {
            return;
        }
        // Skip the leading run of blocks that other requests have already
        // published (refcount > 1); emit from the first newly seen block on.
        let first_new = computed_blocks.iter().position(|block| {
            self.block_hash_refcounts
                .get(&block.block_hash)
                .copied()
                .unwrap_or_default()
                == 1
        });
        let Some(first_new) = first_new else {
            // Every block is already known; nothing new to announce.
            return;
        };
        // Re-base the parent hash onto the last skipped block so the published
        // chain stays contiguous.
        let parent_hash = if first_new == 0 {
            parent_hash
        } else {
            Some(computed_blocks[first_new - 1].block_hash)
        };
let blocks = computed_blocks.into_iter().skip(first_new).collect();
let event = KvCacheEvent {
event_id: self.next_event_id,
data: KvCacheEventData::Stored(KvCacheStoreData {
parent_hash,
blocks,
}),
dp_rank: self.dp_rank,
};
self.next_event_id += 1;
if let Err(e) = self.kv_event_publishers.publish(event, None) {
tracing::warn!("Failed to publish SGLang KV event: {e}");
}
}
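    /// Decrements refcounts for the freed slots and publishes a Removed event
    /// for block hashes that no longer have any live slot. No-op when no
    /// publishers are configured.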
fn publish_removed_event(&mut self, evicted_indices: &[usize]) {
if self.kv_event_publishers.is_empty() {
return;
}
let mut block_hashes = Vec::new();
        for &idx in evicted_indices {
            // Slots with no published hash have nothing to announce.
            let Some(block_hash) = self.idx_to_block_hash.remove(&idx) else {
                continue;
            };
            let Some(refcount) = self.block_hash_refcounts.get_mut(&block_hash) else {
                continue;
            };
            // Announce removal only when the last slot holding this hash is freed.
            if *refcount > 1 {
                *refcount -= 1;
                continue;
            }
self.block_hash_refcounts.remove(&block_hash);
block_hashes.push(block_hash);
}
if block_hashes.is_empty() {
return;
}
let event = KvCacheEvent {
event_id: self.next_event_id,
data: KvCacheEventData::Removed(KvCacheRemoveData { block_hashes }),
dp_rank: self.dp_rank,
};
self.next_event_id += 1;
if let Err(e) = self.kv_event_publishers.publish(event, None) {
tracing::warn!("Failed to publish SGLang KV remove event: {e}");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::Mutex;
use crate::common::protocols::KvCacheEventSink;
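    /// In-memory sink that records every published event so tests can assert
    /// on counts and payloads.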
struct MockSink {
events: Mutex<Vec<KvCacheEvent>>,
}
impl MockSink {
fn new() -> Self {
Self {
events: Mutex::new(Vec::new()),
}
}
fn event_count(&self) -> usize {
self.events.lock().unwrap().len()
}
fn clone_events(&self) -> Vec<KvCacheEvent> {
self.events.lock().unwrap().clone()
}
}
impl KvCacheEventSink for MockSink {
fn publish(&self, event: KvCacheEvent) -> anyhow::Result<()> {
self.events.lock().unwrap().push(event);
Ok(())
}
}
#[test]
fn test_allocate_cache_miss() {
let mut mgr = SglangKvManager::new(100, 1, KvEventPublishers::default(), 0);
let result = mgr.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();
assert_eq!(result.prefix_len, 0);
assert_eq!(result.kv_indices.len(), 5);
assert_eq!(mgr.cache().token_pool.available(), 95);
}
#[test]
fn test_allocate_cache_hit() {
let mut mgr = SglangKvManager::new(100, 1, KvEventPublishers::default(), 0);
let r1 = mgr.allocate_for_request(&[1, 2, 3, 4, 5]).unwrap();
        assert_eq!(r1.kv_indices.len(), 5);
        mgr.cache_finished_req(&[1, 2, 3, 4, 5], &r1.kv_indices, r1.last_node);
let r2 = mgr.allocate_for_request(&[1, 2, 3, 4, 5, 6, 7]).unwrap();
assert_eq!(r2.prefix_len, 5);
        assert_eq!(r2.kv_indices.len(), 7);
        assert_eq!(mgr.cache().token_pool.available(), 93);
    }
#[test]
fn test_free_request_without_caching() {
let mut mgr = SglangKvManager::new(100, 1, KvEventPublishers::default(), 0);
let result = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
mgr.free_request(result.last_node);
assert_eq!(mgr.cache().protected_size, 0);
}
#[test]
fn test_event_publishing() {
let sink = Arc::new(MockSink::new());
let mut mgr =
SglangKvManager::new(100, 1, KvEventPublishers::new(Some(sink.clone()), None), 0);
let r = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
assert_eq!(sink.event_count(), 1);
mgr.cache_finished_req(&[1, 2, 3], &r.kv_indices, r.last_node);
let r2 = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
assert_eq!(r2.prefix_len, 3);
        assert_eq!(sink.event_count(), 1);
    }
#[test]
fn test_duplicate_logical_blocks_publish_once_and_remove_once() {
let sink = Arc::new(MockSink::new());
let mut mgr =
SglangKvManager::new(100, 1, KvEventPublishers::new(Some(sink.clone()), None), 0);
let req1 = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
let req2 = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
let events = sink.clone_events();
assert_eq!(events.len(), 1);
let KvCacheEventData::Stored(store) = &events[0].data else {
panic!("expected stored event");
};
assert_eq!(store.blocks.len(), 3);
mgr.free_indices(&req1.kv_indices);
assert_eq!(sink.event_count(), 1);
mgr.free_indices(&req2.kv_indices);
let events = sink.clone_events();
assert_eq!(events.len(), 2);
let KvCacheEventData::Removed(remove) = &events[1].data else {
panic!("expected removed event");
};
assert_eq!(remove.block_hashes.len(), 3);
}
#[test]
fn test_allocate_oom() {
let mut mgr = SglangKvManager::new(3, 1, KvEventPublishers::default(), 0);
let _r = mgr.allocate_for_request(&[1, 2, 3]).unwrap();
let result = mgr.allocate_for_request(&[4, 5, 6]);
assert!(result.is_none());
}
#[test]
fn test_chunked_prefill_parent_hash() {
let sink = Arc::new(MockSink::new());
let mut mgr =
SglangKvManager::new(32, 1, KvEventPublishers::new(Some(sink.clone()), None), 0);
let tokens = [11, 22, 33, 44, 55, 66];
let chunk1_len = 3;
let chunk2_len = 6;
let alloc1 = mgr.allocate_for_request(&tokens[..chunk1_len]).unwrap();
let new_last =
mgr.cache_unfinished_req(&tokens[..chunk1_len], &alloc1.kv_indices, alloc1.last_node);
let alloc2 = mgr.allocate_for_request(&tokens[..chunk2_len]).unwrap();
mgr.free_request(new_last);
let events = sink.events.lock().unwrap();
assert_eq!(events.len(), 2, "expected two stored events");
let KvCacheEventData::Stored(store1) = &events[0].data else {
panic!("expected first event to be Stored");
};
let KvCacheEventData::Stored(store2) = &events[1].data else {
panic!("expected second event to be Stored");
};
assert!(
store1.parent_hash.is_none(),
"first chunk should start from the root"
);
let last_block_hash = store1
.blocks
.last()
.expect("first chunk should store at least one block")
.block_hash;
assert_eq!(
store2.parent_hash,
Some(last_block_hash),
"second chunk should chain from the last block of chunk 1"
);
assert_eq!(
store2.blocks.len(),
chunk2_len - chunk1_len,
"second chunk should only emit new blocks"
);
assert_eq!(
alloc2.prefix_len, chunk1_len,
"second chunk should reuse the cached partial prefix"
);
}
}