use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use crate::tensor::Tensor;
/// Identifier of a cached token prefix, obtained by hashing the token sequence.
///
/// The wrapped `u64` is a hash value: equal token sequences always produce the
/// same id, while two different sequences can in principle collide.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct PrefixId(pub u64);

impl PrefixId {
    /// Computes the id for `tokens` by feeding the slice through the standard
    /// library's `DefaultHasher`.
    ///
    /// The hasher's algorithm is not guaranteed to stay the same across Rust
    /// releases, so ids should be treated as process-local and not persisted.
    pub fn from_tokens(tokens: &[u32]) -> Self {
        let mut state = std::collections::hash_map::DefaultHasher::new();
        Hash::hash(tokens, &mut state);
        Self(state.finish())
    }
}
/// A cached prompt prefix: the token sequence together with the key/value
/// tensors that were stored for it.
#[derive(Debug, Clone)]
pub struct CachedPrefix {
    /// Token ids that make up the prefix.
    pub tokens: Vec<u32>,
    /// Stored key tensors (the restore path in this file indexes them by layer).
    pub k_cache: Vec<Tensor>,
    /// Stored value tensors, parallel to `k_cache`.
    pub v_cache: Vec<Tensor>,
    /// Number of tokens in the prefix; initialised to `tokens.len()` by `new`.
    pub seq_len: usize,
    /// Number of outstanding users; starts at zero.
    pub ref_count: usize,
    /// Moment the entry was created or last touched.
    pub last_access: std::time::Instant,
}

impl CachedPrefix {
    /// Wraps `tokens` and their key/value tensors in a fresh entry with a zero
    /// reference count and `last_access` set to the current instant.
    pub fn new(tokens: Vec<u32>, k_cache: Vec<Tensor>, v_cache: Vec<Tensor>) -> Self {
        Self {
            // Read the length before `tokens` is moved into the struct.
            seq_len: tokens.len(),
            tokens,
            k_cache,
            v_cache,
            ref_count: 0,
            last_access: std::time::Instant::now(),
        }
    }

    /// Estimates the footprint of this entry: the summed `data().len()` of
    /// every key and value tensor plus four bytes per token.
    ///
    /// NOTE(review): whether `Tensor::data().len()` counts bytes or elements
    /// is not visible here — confirm the tensor part is in the same unit as
    /// the token part.
    pub fn memory_size(&self) -> usize {
        let tensor_total: usize = self
            .k_cache
            .iter()
            .chain(&self.v_cache)
            .map(|tensor| tensor.data().len())
            .sum();
        tensor_total + self.tokens.len() * std::mem::size_of::<u32>()
    }
}
/// Limits and switches that control the prompt cache.
#[derive(Debug, Clone)]
pub struct PromptCacheConfig {
    /// Entry-count threshold at which the cache starts evicting.
    pub max_entries: usize,
    /// Memory threshold, compared against the summed `memory_size()` of entries.
    pub max_memory: usize,
    /// Token sequences shorter than this are not cached.
    pub min_prefix_len: usize,
    /// Whether system prompts should be cached.
    /// NOTE(review): not read by any code visible in this file — confirm usage.
    pub cache_system_prompts: bool,
}

impl Default for PromptCacheConfig {
    /// Defaults: 100 entries, 1 GiB of memory, 32-token minimum prefix,
    /// system-prompt caching enabled.
    fn default() -> Self {
        const ONE_GIB: usize = 1024 * 1024 * 1024;
        Self {
            max_entries: 100,
            max_memory: ONE_GIB,
            min_prefix_len: 32,
            cache_system_prompts: true,
        }
    }
}
/// Bounded in-memory store of cached prompt prefixes, keyed by `PrefixId`.
///
/// The cache keeps a running `memory_used` total (in whatever unit
/// `CachedPrefix::memory_size` reports) and, when a new prefix needs room,
/// evicts the least recently used entry whose reference count is zero.
pub struct PromptCache {
    /// Limits and thresholds supplied at construction.
    config: PromptCacheConfig,
    /// Cached prefixes by id.
    entries: HashMap<PrefixId, CachedPrefix>,
    /// Sum of `memory_size()` over everything currently in `entries`.
    memory_used: usize,
}

impl PromptCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: PromptCacheConfig) -> Self {
        Self {
            config,
            entries: HashMap::new(),
            memory_used: 0,
        }
    }

    /// Stores the key/value tensors for `tokens` and returns the prefix id.
    ///
    /// * If the prefix is already cached, the existing entry is kept (the
    ///   tensors passed in are dropped), its reference count is incremented
    ///   and its access time refreshed.
    /// * If `tokens` is shorter than `min_prefix_len`, nothing is stored.
    /// * Otherwise unreferenced entries are evicted, least recently used
    ///   first, until the new entry fits. When room cannot be made — the entry
    ///   alone exceeds `max_memory`, `max_entries` is zero, or every remaining
    ///   entry is still referenced — the prefix is not stored, so the
    ///   configured limits are never exceeded.
    ///
    /// The returned id is always the hash of `tokens`, whether or not an entry
    /// was stored. Newly stored entries start with a reference count of zero.
    pub fn cache_prefix(
        &mut self,
        tokens: &[u32],
        k_cache: Vec<Tensor>,
        v_cache: Vec<Tensor>,
    ) -> PrefixId {
        let id = PrefixId::from_tokens(tokens);

        if let Some(entry) = self.entries.get_mut(&id) {
            entry.ref_count += 1;
            entry.last_access = std::time::Instant::now();
            return id;
        }
        if tokens.len() < self.config.min_prefix_len {
            return id;
        }

        let prefix = CachedPrefix::new(tokens.to_vec(), k_cache, v_cache);
        let size = prefix.memory_size();

        // An entry that can never fit must not flush the rest of the cache.
        if self.config.max_entries == 0 || size > self.config.max_memory {
            return id;
        }
        while self.lacks_room_for(size) {
            if !self.evict_lru() {
                // Everything left is referenced: skip caching rather than
                // overshoot the configured limits.
                return id;
            }
        }

        self.memory_used += size;
        self.entries.insert(id.clone(), prefix);
        id
    }

    /// Looks up the entry for `id`.
    ///
    /// A hit increments the entry's reference count and refreshes its access
    /// time; pair it with `release_prefix` once the entry is no longer needed,
    /// otherwise the entry can never be evicted.
    pub fn get_prefix(&mut self, id: &PrefixId) -> Option<&CachedPrefix> {
        let entry = self.entries.get_mut(id)?;
        entry.ref_count += 1;
        entry.last_access = std::time::Instant::now();
        Some(&*entry)
    }

    /// Finds the longest cached prefix that `tokens` starts with.
    ///
    /// Returns the entry's id and its length in tokens, or `None` when no
    /// cached prefix matches. A hit also increments the entry's reference
    /// count and refreshes its access time, exactly like `get_prefix`.
    pub fn find_matching_prefix(&mut self, tokens: &[u32]) -> Option<(PrefixId, usize)> {
        let (id, match_len) = self
            .entries
            .iter()
            .filter(|(_, entry)| tokens.starts_with(&entry.tokens))
            .max_by_key(|(_, entry)| entry.tokens.len())
            .map(|(id, entry)| (id.clone(), entry.tokens.len()))?;

        if let Some(entry) = self.entries.get_mut(&id) {
            entry.last_access = std::time::Instant::now();
            entry.ref_count += 1;
        }
        Some((id, match_len))
    }

    /// Removes the entry for `id`, if present, regardless of its reference
    /// count, and subtracts its size from the memory total.
    pub fn remove_prefix(&mut self, id: &PrefixId) {
        if let Some(entry) = self.entries.remove(id) {
            self.memory_used = self.memory_used.saturating_sub(entry.memory_size());
        }
    }

    /// Drops every entry and resets the memory total to zero.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.memory_used = 0;
    }

    /// Returns a snapshot of the entry count, memory total and the summed
    /// `seq_len` of all cached prefixes.
    pub fn stats(&self) -> PromptCacheStats {
        PromptCacheStats {
            num_entries: self.entries.len(),
            memory_used: self.memory_used,
            total_tokens_cached: self.entries.values().map(|e| e.seq_len).sum(),
        }
    }

    /// Reports whether admitting `incoming` more units of memory (and one more
    /// entry) would break a configured limit.
    fn lacks_room_for(&self, incoming: usize) -> bool {
        self.entries.len() >= self.config.max_entries
            || self.memory_used.saturating_add(incoming) > self.config.max_memory
    }

    /// Evicts the least recently accessed entry whose reference count is zero.
    ///
    /// Returns `false` when no such entry exists.
    fn evict_lru(&mut self) -> bool {
        let lru_id = self
            .entries
            .iter()
            .filter(|(_, e)| e.ref_count == 0)
            .min_by_key(|(_, e)| e.last_access)
            .map(|(id, _)| id.clone());

        match lru_id {
            Some(id) => {
                self.remove_prefix(&id);
                true
            }
            None => false,
        }
    }

    /// Gives back one reference on the entry for `id`; the count never drops
    /// below zero and unknown ids are ignored.
    pub fn release_prefix(&mut self, id: &PrefixId) {
        if let Some(entry) = self.entries.get_mut(id) {
            entry.ref_count = entry.ref_count.saturating_sub(1);
        }
    }
}
/// Point-in-time counters describing a `PromptCache`, as produced by its
/// `stats` method.
#[derive(Debug, Clone)]
pub struct PromptCacheStats {
/// Number of prefixes currently stored.
pub num_entries: usize,
/// The cache's running memory total at the time of the snapshot.
pub memory_used: usize,
/// Sum of `seq_len` over all stored prefixes.
pub total_tokens_cached: usize,
}
/// Front end over a `PromptCache` that remembers which prefix is currently in
/// use and holds at most one cache reference on its behalf.
pub struct PrefixSharing {
    /// Underlying prefix store.
    cache: PromptCache,
    /// Prefix most recently restored or saved and not yet released.
    active_prefix: Option<PrefixId>,
}

impl PrefixSharing {
    /// Creates a sharing layer with an empty cache configured by `config`.
    pub fn new(config: PromptCacheConfig) -> Self {
        Self {
            cache: PromptCache::new(config),
            active_prefix: None,
        }
    }

    /// Tries to seed `k_cache` / `v_cache` from the longest cached prefix of
    /// `tokens`.
    ///
    /// Any prefix left active by an earlier call is released first, so its
    /// reference is not leaked when the active slot is overwritten. On a hit,
    /// the cached tensor data is copied layer by layer into the destination
    /// tensors (only for layers present in all four lists, and only as many
    /// elements as both source and destination hold), the matched prefix
    /// becomes the active one with exactly one reference held, and the length
    /// of the matched prefix in tokens is returned. Returns 0 when nothing
    /// matches.
    ///
    /// NOTE(review): the match length is returned even if a destination tensor
    /// yields no mutable data or is smaller than the cached tensor, in which
    /// case that layer is skipped or only partially restored — confirm callers
    /// size their buffers so this cannot happen.
    pub fn try_restore(
        &mut self,
        tokens: &[u32],
        k_cache: &mut [Tensor],
        v_cache: &mut [Tensor],
    ) -> usize {
        self.release_active();

        // `find_matching_prefix` takes one reference on the entry it returns.
        let (id, match_len) = match self.cache.find_matching_prefix(tokens) {
            Some(found) => found,
            None => return 0,
        };
        // `get_prefix` takes a second reference, balanced after the copy below.
        let prefix = match self.cache.get_prefix(&id) {
            Some(entry) => entry,
            None => return 0,
        };

        // Zipping all four lists bounds the loop by the shortest one, so a
        // `v_cache` shorter than `k_cache` can no longer index out of range.
        let cached_layers = prefix.k_cache.iter().zip(prefix.v_cache.iter());
        let target_layers = k_cache.iter_mut().zip(v_cache.iter_mut());
        for ((cached_k, cached_v), (k_layer, v_layer)) in cached_layers.zip(target_layers) {
            let k_src = cached_k.data();
            let v_src = cached_v.data();
            if let Some(k_dst) = k_layer.data_mut() {
                let copy_len = k_src.len().min(k_dst.len());
                k_dst[..copy_len].copy_from_slice(&k_src[..copy_len]);
            }
            if let Some(v_dst) = v_layer.data_mut() {
                let copy_len = v_src.len().min(v_dst.len());
                v_dst[..copy_len].copy_from_slice(&v_src[..copy_len]);
            }
        }

        // Hand back the extra reference so that the single `release_prefix`
        // issued by `release_active` makes the entry evictable again.
        self.cache.release_prefix(&id);
        self.active_prefix = Some(id);
        match_len
    }

    /// Copies the given key/value tensors into the cache under the id of
    /// `tokens` and marks that prefix as active.
    ///
    /// A previously active prefix is released first. The returned id is the
    /// hash of `tokens` even when the cache declines to store the entry.
    pub fn save_prefix(
        &mut self,
        tokens: &[u32],
        k_cache: &[Tensor],
        v_cache: &[Tensor],
    ) -> PrefixId {
        self.release_active();
        let id = self
            .cache
            .cache_prefix(tokens, k_cache.to_vec(), v_cache.to_vec());
        self.active_prefix = Some(id.clone());
        id
    }

    /// Releases the reference held for the active prefix, if any, and clears
    /// the active slot.
    pub fn release_active(&mut self) {
        if let Some(id) = self.active_prefix.take() {
            self.cache.release_prefix(&id);
        }
    }

    /// Returns the underlying cache's statistics.
    pub fn stats(&self) -> PromptCacheStats {
        self.cache.stats()
    }

    /// Forgets the active prefix and empties the cache.
    pub fn clear(&mut self) {
        self.active_prefix = None;
        self.cache.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::DType;

    /// One-element tensor list used as a stand-in for one side of a KV cache.
    fn dummy_layers() -> Vec<Tensor> {
        vec![Tensor::zeros(vec![4, 4], DType::F32)]
    }

    /// Cache with default limits apart from the given minimum prefix length.
    fn cache_with_min_len(min_prefix_len: usize) -> PromptCache {
        PromptCache::new(PromptCacheConfig {
            min_prefix_len,
            ..Default::default()
        })
    }

    /// Equal token sequences share an id; a differing sequence gets another.
    #[test]
    fn test_prefix_id() {
        let base = PrefixId::from_tokens(&[1, 2, 3, 4]);
        assert_eq!(base, PrefixId::from_tokens(&[1, 2, 3, 4]));
        assert_ne!(base, PrefixId::from_tokens(&[1, 2, 3, 5]));
    }

    /// A stored prefix can be fetched back and shows up in the stats.
    #[test]
    fn test_prompt_cache() {
        let mut cache = cache_with_min_len(2);
        let id = cache.cache_prefix(&[1, 2, 3, 4, 5], dummy_layers(), dummy_layers());
        assert!(cache.get_prefix(&id).is_some());
        assert_eq!(cache.stats().num_entries, 1);
    }

    /// Queries starting with a cached prefix match it; diverging ones do not.
    #[test]
    fn test_find_matching_prefix() {
        let mut cache = cache_with_min_len(2);
        cache.cache_prefix(&[1, 2, 3], dummy_layers(), dummy_layers());

        let hit = cache.find_matching_prefix(&[1, 2, 3, 4, 5]);
        assert_eq!(hit.map(|(_, len)| len), Some(3));

        let miss = cache.find_matching_prefix(&[1, 2, 4, 5]);
        assert!(miss.is_none());
    }

    /// Inserting past `max_entries` evicts instead of growing the cache.
    #[test]
    fn test_cache_eviction() {
        let mut cache = PromptCache::new(PromptCacheConfig {
            max_entries: 2,
            min_prefix_len: 1,
            ..Default::default()
        });
        for token in 0..3u32 {
            cache.cache_prefix(&[token], dummy_layers(), dummy_layers());
        }
        assert!(cache.stats().num_entries <= 2);
    }
}