use candle_core::{Device, Result};
use indexmap::IndexMap;
use itertools::Itertools;
use tracing::info;
use crate::{
kv_cache::RecurrentStateSnapshot, paged_attention::block_hash::BlockHash, pipeline::KvCache,
sequence::Sequence,
};
#[derive(PartialEq, Eq, Debug, Hash)]
struct Tokens(Vec<u32>);
impl Tokens {
fn shared_prefix_len(&self, other: &Self) -> usize {
self.0
.iter()
.zip(other.0.iter())
.take_while(|(a, b)| a == b)
.count()
}
}
impl From<Vec<u32>> for Tokens {
fn from(value: Vec<u32>) -> Self {
Self(value)
}
}
#[derive(Clone)]
struct CacheElement {
cache: Vec<Option<KvCache>>,
recurrent_snapshots: Option<Vec<RecurrentStateSnapshot>>,
audio_hashes: Option<Vec<u64>>,
image_hashes: Option<Vec<u64>>,
video_hashes: Option<Vec<u64>>,
}
impl CacheElement {
fn can_rewind_to(&self, len: usize) -> bool {
self.cache
.iter()
.flatten()
.all(|layer| layer.try_set_len(len).is_ok())
}
}
pub struct PrefixCacheManagerV2 {
caches: IndexMap<Tokens, CacheElement>,
paged_recurrent_caches: IndexMap<Vec<BlockHash>, Vec<RecurrentStateSnapshot>>,
n_on_device: usize,
no_prefix_cache: bool,
has_paged_attention: bool,
}
#[derive(Clone)]
pub enum MatchingCache {
Normal {
normal: Vec<Option<KvCache>>,
recurrent_snapshots: Option<Vec<RecurrentStateSnapshot>>,
images_to_keep: usize,
audios_to_keep: usize,
videos_to_keep: usize,
toks: Vec<u32>,
offset: usize,
},
}
impl PrefixCacheManagerV2 {
pub fn new(n_on_device: usize, no_prefix_cache: bool, has_paged_attention: bool) -> Self {
if !no_prefix_cache && !has_paged_attention {
info!("Prefix caching enabled (sequence-level, non-paged attention). Expect higher multi-turn throughput for both text and multimodal.");
}
PrefixCacheManagerV2 {
caches: IndexMap::new(),
paged_recurrent_caches: IndexMap::new(),
n_on_device,
no_prefix_cache,
has_paged_attention,
}
}
fn paged_recurrent_capacity(&self) -> usize {
self.n_on_device.max(1).saturating_mul(8)
}
pub fn add_sequence(
&mut self,
seq: &mut Sequence,
recurrent_snapshots: Option<Vec<RecurrentStateSnapshot>>,
) {
if self.no_prefix_cache {
return;
}
if !self.has_paged_attention {
let cache = seq.normal_cache().to_vec();
self.caches.insert(
seq.get_toks().to_vec().into(),
CacheElement {
cache,
recurrent_snapshots,
image_hashes: seq.image_hashes().map(|x| x.to_vec()),
audio_hashes: seq.audio_hashes().map(|x| x.to_vec()),
video_hashes: seq.video_hashes().map(|x| x.to_vec()),
},
);
}
}
pub fn evict_caches(&mut self) -> Result<usize> {
if self.no_prefix_cache {
return Ok(0);
}
let mut n_on_device = 0;
for cache in self.caches.values() {
let first_non_none = cache
.cache
.iter()
.find_or_first(|x| x.as_ref().is_some_and(|kv| kv.k().ok().flatten().is_some()));
let Some(Some(first_non_none)) = first_non_none else {
continue;
};
let cache_device = match first_non_none {
KvCache::Normal { k, .. } => {
k.all_data().as_ref().expect("No KV cache data").device()
}
KvCache::Rotating { k, .. } => {
k.all_data().as_ref().expect("No KV cache data").device()
}
KvCache::Shared { .. } => continue,
};
if !matches!(cache_device, Device::Cpu) {
n_on_device += 1;
}
}
let mut n_evicted = 0;
for cache in self.caches.values_mut() {
if n_on_device - n_evicted <= self.n_on_device {
break;
}
let first_non_none = cache
.cache
.iter()
.find_or_first(|x| x.as_ref().is_some_and(|kv| kv.k().ok().flatten().is_some()));
let Some(Some(first_non_none)) = first_non_none else {
continue;
};
let cache_device = match first_non_none {
KvCache::Normal { k, .. } => {
k.all_data().as_ref().expect("No KV cache data").device()
}
KvCache::Rotating { k, .. } => {
k.all_data().as_ref().expect("No KV cache data").device()
}
KvCache::Shared { .. } => continue,
};
if !matches!(cache_device, Device::Cpu) {
cache.cache.clear();
n_evicted += 1;
}
}
self.caches.retain(|_tokens, cache| !cache.cache.is_empty());
Ok(n_evicted)
}
pub fn evict_all_caches(&mut self) -> Result<usize> {
let len = self.caches.len();
self.caches.clear();
self.paged_recurrent_caches.clear();
Ok(len)
}
pub fn add_paged_recurrent_prefix(
&mut self,
key: Vec<BlockHash>,
snapshots: Vec<RecurrentStateSnapshot>,
) {
if self.no_prefix_cache
|| !self.has_paged_attention
|| key.is_empty()
|| snapshots.is_empty()
{
return;
}
let _ = self.paged_recurrent_caches.shift_remove(&key);
self.paged_recurrent_caches.insert(key, snapshots);
while self.paged_recurrent_caches.len() > self.paged_recurrent_capacity() {
let _ = self.paged_recurrent_caches.shift_remove_index(0);
}
}
pub fn get_paged_recurrent_prefix(
&mut self,
key: &[BlockHash],
) -> Option<Vec<RecurrentStateSnapshot>> {
if self.no_prefix_cache || !self.has_paged_attention || key.is_empty() {
return None;
}
let key = key.to_vec();
let snapshots = self.paged_recurrent_caches.shift_remove(&key)?;
let out = snapshots.clone();
self.paged_recurrent_caches.insert(key, snapshots);
Some(out)
}
pub fn search_for_matching_cache(
&mut self,
toks: &[u32],
image_hashes: Option<&[u64]>,
audio_hashes: Option<&[u64]>,
video_hashes: Option<&[u64]>,
) -> Result<Option<MatchingCache>> {
if self.no_prefix_cache || toks.is_empty() {
return Ok(None);
}
if self.has_paged_attention {
return Ok(None);
}
let toks = Tokens(toks.to_vec());
let mut best_match: Option<(usize, &CacheElement, usize, usize, usize)> = None;
for (k, v) in &self.caches {
let match_len = toks.shared_prefix_len(k);
if match_len == 0 {
continue;
}
let images_match_until = match image_hashes {
Some(input_hashes) => match &v.image_hashes {
Some(cached_hashes) => input_hashes
.iter()
.zip(cached_hashes)
.take_while(|(a, b)| a == b)
.count(),
None => 0,
},
None => 0,
};
let audios_match_until = match audio_hashes {
Some(input_hashes) => match &v.audio_hashes {
Some(cached_hashes) => input_hashes
.iter()
.zip(cached_hashes)
.take_while(|(a, b)| a == b)
.count(),
None => 0,
},
None => 0,
};
let cached_image_count = v.image_hashes.as_ref().map_or(0, |h| h.len());
let input_image_count = image_hashes.map_or(0, |h| h.len());
if images_match_until < input_image_count.min(cached_image_count) {
continue;
}
let cached_audio_count = v.audio_hashes.as_ref().map_or(0, |h| h.len());
let input_audio_count = audio_hashes.map_or(0, |h| h.len());
if audios_match_until < input_audio_count.min(cached_audio_count) {
continue;
}
let videos_match_until = match video_hashes {
Some(input_hashes) => match &v.video_hashes {
Some(cached_hashes) => input_hashes
.iter()
.zip(cached_hashes)
.take_while(|(a, b)| a == b)
.count(),
None => 0,
},
None => 0,
};
let cached_video_count = v.video_hashes.as_ref().map_or(0, |h| h.len());
let input_video_count = video_hashes.map_or(0, |h| h.len());
if videos_match_until < input_video_count.min(cached_video_count) {
continue;
}
if !v.can_rewind_to(match_len) {
continue;
}
if best_match
.as_ref()
.is_none_or(|(len, _, _, _, _)| match_len > *len)
{
best_match = Some((
match_len,
v,
images_match_until,
audios_match_until,
videos_match_until,
));
}
}
if let Some((
match_len,
cache_element,
images_match_until,
audios_match_until,
videos_match_until,
)) = best_match
{
let new_toks = toks.0[match_len..].to_vec();
if new_toks.is_empty() {
return Ok(None);
}
let mut cache = cache_element.clone();
let images_to_keep = if let Some(input_hashes) = image_hashes {
input_hashes.len().saturating_sub(images_match_until)
} else {
0
};
let audios_to_keep = if let Some(input_hashes) = audio_hashes {
input_hashes.len().saturating_sub(audios_match_until)
} else {
0
};
let videos_to_keep = if let Some(input_hashes) = video_hashes {
input_hashes.len().saturating_sub(videos_match_until)
} else {
0
};
for layer in cache.cache.iter_mut().flatten() {
if layer.try_set_len(match_len).is_err() {
return Ok(None);
}
}
for layer in cache.cache.iter_mut().flatten() {
layer.set_len(match_len)?;
}
return Ok(Some(MatchingCache::Normal {
normal: cache.cache,
recurrent_snapshots: cache.recurrent_snapshots,
images_to_keep,
audios_to_keep,
videos_to_keep,
toks: new_toks,
offset: match_len,
}));
}
Ok(None)
}
}
#[cfg(test)]
mod tests {
use candle_core::{DType, Device, Tensor};
use super::{CacheElement, MatchingCache, PrefixCacheManagerV2};
use crate::kv_cache::{KvCache, RotatingCache, SingleCache};
fn make_cache_tensor(len: usize) -> candle_core::Result<Tensor> {
Tensor::zeros((1, 1, len, 1), DType::F32, &Device::Cpu)
}
fn make_rotating_kv_cache(
logical_len: usize,
sliding_window: usize,
) -> candle_core::Result<KvCache> {
let src = make_cache_tensor(logical_len)?;
let mut k = RotatingCache::new(2, sliding_window, sliding_window);
let mut v = RotatingCache::new(2, sliding_window, sliding_window);
let _ = k.append(&src)?;
let _ = v.append(&src)?;
Ok(KvCache::Rotating { k, v })
}
fn make_normal_kv_cache(logical_len: usize) -> candle_core::Result<KvCache> {
let src = make_cache_tensor(logical_len)?;
let mut k = SingleCache::new(2, logical_len, logical_len);
let mut v = SingleCache::new(2, logical_len, logical_len);
k.append(&src)?;
v.append(&src)?;
Ok(KvCache::Normal { k, v })
}
#[test]
fn skips_rolled_over_rotating_candidate_that_cannot_rewind() -> candle_core::Result<()> {
let mut prefix_cacher = PrefixCacheManagerV2::new(1, false, false);
prefix_cacher.caches.insert(
vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10].into(),
CacheElement {
cache: vec![Some(make_rotating_kv_cache(10, 4)?)],
recurrent_snapshots: None,
audio_hashes: None,
image_hashes: None,
video_hashes: None,
},
);
prefix_cacher.caches.insert(
vec![1, 2, 3, 4, 5, 9].into(),
CacheElement {
cache: vec![Some(make_normal_kv_cache(6)?)],
recurrent_snapshots: None,
audio_hashes: None,
image_hashes: None,
video_hashes: None,
},
);
let hit = prefix_cacher.search_for_matching_cache(
&[1, 2, 3, 4, 5, 6, 7, 99],
None,
None,
None,
)?;
match hit {
Some(MatchingCache::Normal { toks, offset, .. }) => {
assert_eq!(offset, 5);
assert_eq!(toks, vec![6, 7, 99]);
}
None => panic!("expected a shorter valid prefix-cache hit"),
}
Ok(())
}
#[test]
fn allows_exact_extension_from_rolled_over_rotating_cache() -> candle_core::Result<()> {
let mut prefix_cacher = PrefixCacheManagerV2::new(1, false, false);
prefix_cacher.caches.insert(
vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10].into(),
CacheElement {
cache: vec![Some(make_rotating_kv_cache(10, 4)?)],
recurrent_snapshots: None,
audio_hashes: None,
image_hashes: None,
video_hashes: None,
},
);
let hit = prefix_cacher.search_for_matching_cache(
&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
None,
None,
None,
)?;
match hit {
Some(MatchingCache::Normal { toks, offset, .. }) => {
assert_eq!(offset, 10);
assert_eq!(toks, vec![11]);
}
None => panic!("expected exact-extension prefix-cache hit"),
}
Ok(())
}
}