use std::{
collections::{HashMap, VecDeque},
hash::Hash,
};
use crate::{
error::{Error, MissingKeyPayload, Result},
lm::cache::{KvCache, can_trim_prompt_cache, trim_prompt_cache},
};
/// Step size of 256 associated with token buffers.
///
/// NOTE(review): not referenced anywhere in the visible part of this file; presumably the
/// increment (in tokens) by which a token buffer is meant to grow — confirm against callers.
pub const TOKEN_BUFFER_STEP: usize = 256;
/// Growable, append-only list of token ids.
#[derive(Debug, Clone, Default)]
pub struct TokenBuffer {
    buffer: Vec<i32>,
}

impl TokenBuffer {
    /// Creates a buffer holding a copy of `tokens`.
    pub fn new(tokens: &[i32]) -> Self {
        Self {
            buffer: Vec::from(tokens),
        }
    }

    /// Appends `tokens` to the end of the buffer and returns every token stored so far.
    pub fn update_and_fetch(&mut self, tokens: &[i32]) -> &[i32] {
        self.buffer.extend(tokens.iter().copied());
        self.buffer.as_slice()
    }

    /// Returns all tokens currently stored, oldest first.
    pub fn tokens(&self) -> &[i32] {
        self.buffer.as_slice()
    }

    /// Number of tokens currently stored.
    pub fn len(&self) -> usize {
        self.tokens().len()
    }

    /// `true` when no token has been stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Insertion-ordered association list mapping a token id to the child node
/// reached through that token.
///
/// Lookups are linear scans over `entries`. Iteration yields children in the
/// order they were first inserted, which keeps `PromptTrie::search`
/// deterministic when several stored sequences tie.
struct ChildMap<V> {
    entries: Vec<(i32, TrieNode<V>)>,
}

impl<V> ChildMap<V> {
    /// Creates an empty map.
    fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// `true` when there are no children.
    fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Number of children.
    fn len(&self) -> usize {
        self.entries.len()
    }

    /// Index of the entry keyed by `tok`, if present.
    fn position(&self, tok: i32) -> Option<usize> {
        self.entries.iter().position(|(key, _)| *key == tok)
    }

    /// Shared access to the child reached through `tok`.
    fn get(&self, tok: i32) -> Option<&TrieNode<V>> {
        self.entries
            .iter()
            .find(|(key, _)| *key == tok)
            .map(|(_, node)| node)
    }

    /// Exclusive access to the child reached through `tok`.
    fn get_mut(&mut self, tok: i32) -> Option<&mut TrieNode<V>> {
        self.entries
            .iter_mut()
            .find(|(key, _)| *key == tok)
            .map(|(_, node)| node)
    }

    /// Returns the child reached through `tok`, appending an empty node first
    /// if there is none yet.
    fn get_or_insert(&mut self, tok: i32) -> &mut TrieNode<V> {
        let idx = match self.position(tok) {
            Some(existing) => existing,
            None => {
                self.entries.push((tok, TrieNode::new()));
                self.entries.len() - 1
            }
        };
        &mut self.entries[idx].1
    }

    /// Removes (and drops) the child reached through `tok`, keeping the
    /// relative order of the remaining children. No-op when absent.
    fn remove(&mut self, tok: i32) {
        if let Some(idx) = self.position(tok) {
            self.entries.remove(idx);
        }
    }

    /// Iterates over `(token, child)` pairs in insertion order.
    fn iter_insertion_order(&self) -> impl Iterator<Item = (i32, &TrieNode<V>)> {
        self.entries.iter().map(|(key, node)| (*key, node))
    }
}

/// One trie node: the value stored for the token sequence ending here (if
/// any) plus the children that extend that sequence by one token.
struct TrieNode<V> {
    children: ChildMap<V>,
    value: Option<V>,
}

impl<V> TrieNode<V> {
    /// Creates a node with no value and no children.
    fn new() -> Self {
        Self {
            children: ChildMap::new(),
            value: None,
        }
    }

    /// `true` when the node carries neither a value nor children.
    fn is_empty(&self) -> bool {
        self.children.is_empty() && self.value.is_none()
    }

    /// Detaches the chain of now-useless nodes that ends at `path`.
    ///
    /// `self` is the root of the walk. If the node reached through `path` is
    /// empty, the longest run of nodes above it that exist only to reach it
    /// (no value, no other child) is removed together with it. The root
    /// itself is never removed. Does nothing when `path` is empty, does not
    /// exist, or ends at a node that is still in use.
    ///
    /// Implemented as two iterative walks so that the stack depth does not
    /// grow with the length of `path`.
    fn prune_path(&mut self, path: &[i32]) {
        // Walk 1 (read-only): find the deepest node on the path that must
        // survive, i.e. one that holds a value or has another child.
        let mut keep = 0;
        let mut node = &*self;
        for (depth, &tok) in path.iter().enumerate() {
            if node.value.is_some() || node.children.len() > 1 {
                keep = depth;
            }
            match node.children.get(tok) {
                Some(next) => node = next,
                None => return,
            }
        }
        if path.is_empty() || !node.is_empty() {
            return;
        }

        // Walk 2: cut the edge just below the surviving node. Everything under
        // that edge is a value-less single-child chain ending at the empty node.
        let mut parent = self;
        for &tok in &path[..keep] {
            match parent.children.get_mut(tok) {
                Some(next) => parent = next,
                None => return,
            }
        }
        parent.children.remove(path[keep]);
    }

    /// Finds the shortest token suffix leading from this node to a node that
    /// holds a value (the empty suffix if this node holds one itself).
    ///
    /// Children are visited last-inserted first, and a candidate only
    /// replaces the current best when it is strictly shorter, so among
    /// equally short candidates the first one visited wins.
    ///
    /// A single path buffer is shared by the whole traversal (truncate on
    /// backtrack) instead of cloning the path for every pending node, which
    /// keeps the cost proportional to the number of nodes visited.
    fn shortest_extension(&self) -> Option<Vec<i32>> {
        let mut best: Option<Vec<i32>> = None;
        let mut path: Vec<i32> = Vec::new();
        // (depth below `self`, token of the edge taken to get here, node)
        let mut stack: Vec<(usize, Option<i32>, &TrieNode<V>)> = vec![(0, None, self)];
        while let Some((depth, edge, node)) = stack.pop() {
            if let Some(tok) = edge {
                // The parent sits at `depth - 1` and is an ancestor of the
                // previously visited node, so its path is still in the buffer.
                path.truncate(depth - 1);
                path.push(tok);
            }
            // Neither record nor descend unless this could beat the best so far.
            if best.as_ref().is_some_and(|b| depth >= b.len()) {
                continue;
            }
            if node.value.is_some() {
                best = Some(path.clone());
            } else {
                for (tok, child) in node.children.iter_insertion_order() {
                    stack.push((depth + 1, Some(tok), child));
                }
            }
        }
        best
    }
}

impl<V> Drop for TrieNode<V> {
    /// Tears the subtree down iteratively.
    ///
    /// The trie is as deep as the longest stored token sequence, so the
    /// compiler-generated recursive drop could exhaust the stack on long
    /// prompts. Descendants are moved onto a heap-allocated work list; by the
    /// time any node is actually dropped it no longer has children.
    fn drop(&mut self) {
        let mut pending: Vec<TrieNode<V>> = self
            .children
            .entries
            .drain(..)
            .map(|(_, child)| child)
            .collect();
        while let Some(mut node) = pending.pop() {
            pending.extend(node.children.entries.drain(..).map(|(_, child)| child));
        }
    }
}

/// Outcome of [`PromptTrie::search`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PromptTrieResult {
    /// The query itself, when a value is stored for exactly that sequence.
    pub exact: Option<Vec<i32>>,
    /// Longest stored sequence that is a proper prefix of the query. Only
    /// reported when that prefix is at least two tokens long.
    pub shorter: Option<Vec<i32>>,
    /// A shortest stored sequence that starts with the first `common_prefix`
    /// tokens of the query.
    pub longer: Option<Vec<i32>>,
    /// Number of leading query tokens that could be followed in the trie.
    /// Always 0 when `exact` is set or the model is unknown.
    pub common_prefix: usize,
}

/// Per-model trie keyed by token sequences, storing one `V` per sequence.
pub struct PromptTrie<M, V = ()> {
    trie: HashMap<M, TrieNode<V>>,
}

impl<M: Eq + Hash + Clone, V> Default for PromptTrie<M, V> {
    fn default() -> Self {
        Self::new()
    }
}

impl<M: Eq + Hash + Clone, V> PromptTrie<M, V> {
    /// Creates an empty trie.
    pub fn new() -> Self {
        Self {
            trie: HashMap::new(),
        }
    }

    /// Stores `value` under `(model, tokens)`, creating nodes as needed.
    ///
    /// Returns the value previously stored under the same key, if any.
    pub fn add(&mut self, model: &M, tokens: &[i32], value: V) -> Option<V> {
        let mut node = self.trie.entry(model.clone()).or_insert_with(TrieNode::new);
        for &tok in tokens {
            node = node.children.get_or_insert(tok);
        }
        node.value.replace(value)
    }

    /// Returns the value stored under exactly `(model, tokens)`.
    pub fn get(&self, model: &M, tokens: &[i32]) -> Option<&V> {
        let mut node = self.trie.get(model)?;
        for &tok in tokens {
            node = node.children.get(tok)?;
        }
        node.value.as_ref()
    }

    /// Removes and returns the value stored under exactly `(model, tokens)`.
    ///
    /// Nodes that only existed to reach the removed value are pruned, and the
    /// model's root is dropped from the map once nothing is stored under it.
    /// Returns `None` (and changes nothing) when no such value exists.
    ///
    /// Runs iteratively, so the stack depth is independent of `tokens.len()`.
    pub fn pop(&mut self, model: &M, tokens: &[i32]) -> Option<V> {
        let root = self.trie.get_mut(model)?;
        let mut node = &mut *root;
        for &tok in tokens {
            node = node.children.get_mut(tok)?;
        }
        let value = node.value.take()?;
        root.prune_path(tokens);
        if root.is_empty() {
            self.trie.remove(model);
        }
        Some(value)
    }

    /// Removes every value stored under a proper prefix of `tokens`.
    ///
    /// Returns `(prefix_len, value)` pairs ordered by increasing prefix
    /// length. The value stored under `tokens` itself (if any) is left alone.
    /// Nodes left without value or children by the removal are pruned, so no
    /// dangling path remains when `tokens` is not itself stored.
    pub fn pop_prefixes(&mut self, model: &M, tokens: &[i32]) -> Vec<(usize, V)> {
        let mut values = Vec::new();
        let Some(root) = self.trie.get_mut(model) else {
            return values;
        };
        let mut node = &mut *root;
        // Depth of the deepest node reached while following `tokens`.
        let mut reached = 0;
        for (depth, &tok) in tokens.iter().enumerate() {
            if let Some(value) = node.value.take() {
                values.push((depth, value));
            }
            match node.children.get_mut(tok) {
                Some(next) => {
                    node = next;
                    reached = depth + 1;
                }
                None => break,
            }
        }
        if !values.is_empty() {
            root.prune_path(&tokens[..reached]);
            if root.is_empty() {
                self.trie.remove(model);
            }
        }
        values
    }

    /// Looks up the stored sequences closest to `tokens` for `model`.
    ///
    /// * If a value is stored for exactly `tokens`, only `exact` is set.
    /// * Otherwise `shorter` is the longest stored proper prefix of `tokens`,
    ///   `common_prefix` is how many leading tokens could be followed in the
    ///   trie, and `longer` is a shortest stored sequence that starts with
    ///   those `common_prefix` tokens (only searched for when
    ///   `common_prefix > 0`).
    pub fn search(&self, model: &M, tokens: &[i32]) -> PromptTrieResult {
        let Some(root) = self.trie.get(model) else {
            return PromptTrieResult {
                exact: None,
                shorter: None,
                longer: None,
                common_prefix: 0,
            };
        };
        if tokens.is_empty() && root.value.is_some() {
            return PromptTrieResult {
                exact: Some(Vec::new()),
                shorter: None,
                longer: None,
                common_prefix: 0,
            };
        }

        // Follow the query as far as the trie allows, remembering the index of
        // the last token at which a stored value was seen.
        let mut node = root;
        let mut matched = 0;
        let mut last_hit: Option<usize> = None;
        for (i, &tok) in tokens.iter().enumerate() {
            let Some(next) = node.children.get(tok) else {
                break;
            };
            node = next;
            if node.value.is_some() {
                last_hit = Some(i);
            }
            matched = i + 1;
        }

        if last_hit.is_some_and(|i| i + 1 == tokens.len()) {
            return PromptTrieResult {
                exact: Some(tokens.to_vec()),
                shorter: None,
                longer: None,
                common_prefix: 0,
            };
        }

        // NOTE(review): a stored one-token prefix (`last_hit == Some(0)`) is
        // not reported as `shorter`. This preserves the existing behaviour;
        // confirm whether it is intentional.
        let shorter = last_hit
            .filter(|&i| i > 0)
            .map(|i| tokens[..=i].to_vec());

        let longer = if matched > 0 {
            node.shortest_extension().map(|extra| {
                let mut full = tokens[..matched].to_vec();
                full.extend(extra);
                full
            })
        } else {
            None
        };

        PromptTrieResult {
            exact: None,
            shorter,
            longer,
            common_prefix: matched,
        }
    }
}
/// A cached prompt state stored in the trie, together with the bookkeeping
/// needed to account for it and to evict it.
struct CacheEntry {
    /// The cached state, as a list of boxed `KvCache` objects.
    /// NOTE(review): presumably one element per model layer — confirm.
    prompt_cache: Vec<Box<dyn KvCache>>,
    /// Sum of `nbytes()` over `prompt_cache`, captured when the entry was inserted.
    nbytes: usize,
    /// Name of the `CacheOrder` bucket the entry was filed under (e.g. "assistant").
    cache_type: String,
}
/// Eviction order for cached prompts: one FIFO queue of `(model, tokens)` keys
/// per cache type, with the types themselves ranked by `ordering`.
struct CacheOrder<M> {
    /// Cache type names, in the order `pop` considers them.
    ordering: Vec<String>,
    /// One queue per cache type; the front is the oldest key.
    lrus: HashMap<String, VecDeque<(M, Vec<i32>)>>,
}

impl<M: Eq + Clone> CacheOrder<M> {
    /// Creates empty queues for the "assistant", "user" and "system" types.
    fn new() -> Self {
        let ordering = vec![
            "assistant".to_string(),
            "user".to_string(),
            "system".to_string(),
        ];
        let mut lrus = HashMap::with_capacity(ordering.len());
        for name in &ordering {
            lrus.insert(name.clone(), VecDeque::new());
        }
        Self { ordering, lrus }
    }

    /// Total number of keys across all queues.
    fn len(&self) -> usize {
        self.lrus.values().fold(0, |total, queue| total + queue.len())
    }

    /// Appends `(model, tokens)` as the newest key of `cache_type`.
    /// Unknown cache types are silently ignored.
    fn push(&mut self, model: &M, tokens: &[i32], cache_type: &str) {
        let Some(queue) = self.lrus.get_mut(cache_type) else {
            return;
        };
        queue.push_back((model.clone(), tokens.to_vec()));
    }

    /// Removes the first occurrence of `(model, tokens)`, scanning the queues
    /// in `ordering` order. At most one key is removed.
    fn remove(&mut self, model: &M, tokens: &[i32]) {
        for name in &self.ordering {
            let Some(queue) = self.lrus.get_mut(name) else {
                continue;
            };
            let hit = queue
                .iter()
                .position(|(m, t)| m == model && t.as_slice() == tokens);
            if let Some(idx) = hit {
                queue.remove(idx);
                return;
            }
        }
    }

    /// Pops the next key to evict.
    ///
    /// Walks adjacent pairs of cache types in `ordering` order and takes the
    /// oldest key of the first type that is non-empty and holds at least as
    /// many keys as the type after it; if no pair qualifies, the last type's
    /// queue is used.
    fn pop(&mut self) -> Option<(M, Vec<i32>)> {
        let last = self.ordering.len().checked_sub(1)?;
        let chosen = (0..last)
            .find(|&i| {
                let here = self.type_len(&self.ordering[i]);
                here > 0 && here >= self.type_len(&self.ordering[i + 1])
            })
            .unwrap_or(last);
        let bucket = &self.ordering[chosen];
        self.lrus.get_mut(bucket)?.pop_front()
    }

    /// Number of keys queued under `cache_type` (0 for unknown types).
    fn type_len(&self, cache_type: &str) -> usize {
        match self.lrus.get(cache_type) {
            Some(queue) => queue.len(),
            None => 0,
        }
    }
}
/// Per-cache-type usage figures, as reported by `LruPromptCache::stats_by_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CacheTypeStats {
    /// Number of sequences currently queued under the cache type.
    pub n_sequences: usize,
    /// Bytes currently accounted to the cache type.
    pub n_bytes: usize,
}
/// Sentinel meaning "no limit": `LruPromptCache::trim_to` substitutes it for a
/// bound that was passed as `None`.
///
/// The value is the highest bit of `usize` (2^63 on 64-bit targets). It is
/// written relative to `usize::BITS` because a literal `1 << 63` is a
/// compile-time shift overflow wherever `usize` is narrower than 64 bits.
pub const LRU_UNBOUNDED: usize = 1 << (usize::BITS - 1);
/// Prompt cache keyed by `(model, token sequence)` and bounded both by entry
/// count and by total byte size.
///
/// Entries are stored in a `PromptTrie`; the eviction order is delegated to
/// `CacheOrder`.
pub struct LruPromptCache<M: Eq + Hash + Clone> {
    /// Entry-count limit enforced after each insertion.
    max_size: usize,
    /// Byte limit enforced after each insertion.
    max_bytes: usize,
    /// Storage for the cached entries.
    trie: PromptTrie<M, CacheEntry>,
    /// Eviction order of the stored keys.
    lru: CacheOrder<M>,
    /// Sum of `nbytes` over every stored entry.
    n_bytes: usize,
    /// The same total, split by cache type.
    n_bytes_by_type: HashMap<String, usize>,
}

impl<M: Eq + Hash + Clone> LruPromptCache<M> {
    /// Creates an empty cache holding at most `max_size` entries and
    /// `max_bytes` bytes.
    pub fn new(max_size: usize, max_bytes: usize) -> Self {
        let lru = CacheOrder::new();
        let n_bytes_by_type = lru
            .ordering
            .iter()
            .cloned()
            .map(|cache_type| (cache_type, 0usize))
            .collect();
        Self {
            max_size,
            max_bytes,
            trie: PromptTrie::new(),
            lru,
            n_bytes: 0,
            n_bytes_by_type,
        }
    }

    /// Number of cached sequences.
    pub fn len(&self) -> usize {
        self.lru.len()
    }

    /// `true` when nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Total bytes accounted to the cached entries.
    pub fn nbytes(&self) -> usize {
        self.n_bytes
    }

    /// Subtracts `entry`'s size from the global and per-type byte counters.
    fn release_bytes(&mut self, entry: &CacheEntry) {
        self.n_bytes -= entry.nbytes;
        if let Some(per_type) = self.n_bytes_by_type.get_mut(&entry.cache_type) {
            *per_type -= entry.nbytes;
        }
    }

    /// Evicts the next key chosen by the eviction order.
    ///
    /// Returns `false` when there was nothing to pop or the popped key had no
    /// entry in the trie; callers use that to stop evicting.
    fn evict_lru_entry(&mut self) -> bool {
        let Some((model, tokens)) = self.lru.pop() else {
            return false;
        };
        let Some(evicted) = self.trie.pop(&model, &tokens) else {
            return false;
        };
        self.release_bytes(&evicted);
        true
    }

    /// Finds the cached state that best matches `tokens` for `model`.
    ///
    /// Returns `(cache, remaining)`, where `remaining` holds the tokens not
    /// covered by the returned cache. Candidates are tried in this order:
    ///
    /// 1. an entry for exactly `tokens` (copied; `remaining` is empty);
    /// 2. a longer entry sharing more leading tokens than the best shorter
    ///    entry, if it can be trimmed (copied, then trimmed);
    /// 3. the best shorter entry (copied);
    /// 4. otherwise `(None, tokens)`.
    ///
    /// # Errors
    /// Propagates failures from copying or trimming the cache.
    #[allow(clippy::type_complexity)]
    pub fn fetch_nearest_cache(
        &self,
        model: &M,
        tokens: &[i32],
    ) -> Result<(Option<Vec<Box<dyn KvCache>>>, Vec<i32>)> {
        let found = self.trie.search(model, tokens);

        let exact_entry = found
            .exact
            .as_deref()
            .and_then(|key| self.trie.get(model, key));
        if let Some(entry) = exact_entry {
            let copied = copy_prompt_cache(&entry.prompt_cache)?;
            return Ok((Some(copied), Vec::new()));
        }

        let short_length = found.shorter.as_ref().map_or(0, Vec::len);

        let trimmable_longer = found
            .longer
            .as_deref()
            .filter(|_| found.common_prefix > short_length)
            .and_then(|key| self.trie.get(model, key).map(|entry| (key, entry)))
            .filter(|(_, entry)| can_trim_prompt_cache(&entry.prompt_cache));
        if let Some((longer, entry)) = trimmable_longer {
            let mut cache = copy_prompt_cache(&entry.prompt_cache)?;
            // Never reuse more than `tokens.len() - 1` tokens, so at least one
            // token is always handed back to the caller.
            let prefix = tokens.len().saturating_sub(1).min(found.common_prefix);
            let num_to_trim = longer.len().saturating_sub(prefix);
            trim_prompt_cache(&mut cache, num_to_trim)?;
            return Ok((Some(cache), tokens[prefix..].to_vec()));
        }

        let shorter_entry = found
            .shorter
            .as_deref()
            .filter(|key| !key.is_empty())
            .and_then(|key| self.trie.get(model, key));
        if let Some(entry) = shorter_entry {
            let copied = copy_prompt_cache(&entry.prompt_cache)?;
            return Ok((Some(copied), tokens[short_length..].to_vec()));
        }

        Ok((None, tokens.to_vec()))
    }

    /// Stores `prompt_cache` under `(model, tokens)` in the `cache_type` bucket.
    ///
    /// An existing entry under the same key is replaced. If the new cache can
    /// be trimmed, entries stored under proper prefixes of `tokens` are
    /// dropped. Afterwards at most one entry is evicted to respect `max_size`,
    /// and entries are evicted until the byte total fits `max_bytes`.
    ///
    /// # Errors
    /// Returns `Error::MissingKey` when `cache_type` is not a known bucket.
    pub fn insert_cache(
        &mut self,
        model: &M,
        tokens: &[i32],
        prompt_cache: Vec<Box<dyn KvCache>>,
        cache_type: &str,
    ) -> Result<()> {
        let known_type = self.lru.ordering.iter().any(|bucket| bucket == cache_type);
        if !known_type {
            // NOTE(review): the message names `LruPromptCache::add`, but this
            // method is `insert_cache`; left unchanged in case it is matched on.
            return Err(Error::MissingKey(MissingKeyPayload::new(
                "LruPromptCache::add: cache_type (must be one of the configured CacheOrder buckets)",
                cache_type,
            )));
        }

        let trimmable = can_trim_prompt_cache(&prompt_cache);
        let entry_nbytes: usize = prompt_cache.iter().map(|cache| cache.nbytes()).sum();
        let new_entry = CacheEntry {
            prompt_cache,
            nbytes: entry_nbytes,
            cache_type: cache_type.to_string(),
        };

        // Account for the new entry first, then undo the accounting of
        // whatever it replaced.
        self.n_bytes += entry_nbytes;
        if let Some(per_type) = self.n_bytes_by_type.get_mut(cache_type) {
            *per_type += entry_nbytes;
        }
        if let Some(replaced) = self.trie.add(model, tokens, new_entry) {
            self.release_bytes(&replaced);
            self.lru.remove(model, tokens);
        }
        self.lru.push(model, tokens, cache_type);

        if trimmable {
            for (prefix_len, prefix_entry) in self.trie.pop_prefixes(model, tokens) {
                self.release_bytes(&prefix_entry);
                self.lru.remove(model, &tokens[..prefix_len]);
            }
        }

        if self.lru.len() > self.max_size {
            self.evict_lru_entry();
        }
        while self.n_bytes > self.max_bytes {
            if !self.evict_lru_entry() {
                break;
            }
        }
        Ok(())
    }

    /// Shorthand for [`insert_cache`](Self::insert_cache) with the
    /// "assistant" cache type.
    pub fn insert_cache_assistant(
        &mut self,
        model: &M,
        tokens: &[i32],
        prompt_cache: Vec<Box<dyn KvCache>>,
    ) -> Result<()> {
        self.insert_cache(model, tokens, prompt_cache, "assistant")
    }

    /// Evicts entries until at most `n_sequences` entries and `n_bytes` bytes
    /// remain. A bound given as `None` is replaced by `LRU_UNBOUNDED`.
    pub fn trim_to(&mut self, n_sequences: Option<usize>, n_bytes: Option<usize>) {
        let sequence_limit = n_sequences.unwrap_or(LRU_UNBOUNDED);
        let byte_limit = n_bytes.unwrap_or(LRU_UNBOUNDED);
        while self.lru.len() > sequence_limit {
            if !self.evict_lru_entry() {
                break;
            }
        }
        while self.n_bytes > byte_limit {
            if !self.evict_lru_entry() {
                break;
            }
        }
    }

    /// Reports the sequence count and byte total of every cache type.
    pub fn stats_by_type(&self) -> HashMap<String, CacheTypeStats> {
        self.lru
            .ordering
            .iter()
            .map(|cache_type| {
                let stats = CacheTypeStats {
                    n_sequences: self.lru.type_len(cache_type),
                    n_bytes: self.n_bytes_by_type.get(cache_type).copied().unwrap_or(0),
                };
                (cache_type.clone(), stats)
            })
            .collect()
    }
}
/// Builds a new prompt cache by calling `copy()` on every element of `cache`,
/// in order, stopping at the first error.
fn copy_prompt_cache(cache: &[Box<dyn KvCache>]) -> Result<Vec<Box<dyn KvCache>>> {
    let mut copies = Vec::with_capacity(cache.len());
    for item in cache {
        copies.push(item.copy()?);
    }
    Ok(copies)
}