use std::collections::BTreeMap;
use std::sync::Arc;
use parking_lot::Mutex;
use crate::token::LlamaToken;
#[derive(Debug, Clone)]
pub struct CacheEntry {
pub state: Vec<u8>,
pub n_past: i32,
}
pub trait Cache: Send + Sync {
fn lookup(&self, tokens: &[LlamaToken]) -> Option<CacheEntry>;
fn store(&self, tokens: &[LlamaToken], entry: CacheEntry);
fn len(&self) -> usize;
fn is_empty(&self) -> bool {
self.len() == 0
}
}
#[derive(Debug, Default, Clone)]
pub struct RamCache {
inner: Arc<Mutex<BTreeMap<Vec<LlamaToken>, CacheEntry>>>,
}
impl RamCache {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn clear(&self) {
self.inner.lock().clear();
}
}
impl Cache for RamCache {
fn lookup(&self, tokens: &[LlamaToken]) -> Option<CacheEntry> {
let mut best_len = 0_usize;
let mut best_entry: Option<CacheEntry> = None;
let g = self.inner.lock();
for (key, val) in g.iter() {
if key.len() > tokens.len() || key.len() <= best_len {
continue;
}
if tokens.starts_with(key) {
best_len = key.len();
best_entry = Some(val.clone());
}
}
best_entry
}
fn store(&self, tokens: &[LlamaToken], entry: CacheEntry) {
self.inner.lock().insert(tokens.to_vec(), entry);
}
fn len(&self) -> usize {
self.inner.lock().len()
}
}
#[cfg(feature = "disk-cache")]
pub struct DiskCache {
db: sled::Db,
tree: sled::Tree,
}
#[cfg(feature = "disk-cache")]
impl std::fmt::Debug for DiskCache {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DiskCache").finish()
}
}
#[cfg(feature = "disk-cache")]
impl DiskCache {
pub fn open(path: impl AsRef<std::path::Path>) -> Result<Self, sled::Error> {
let db = sled::open(path.as_ref())?;
let tree = db.open_tree("llama_crab.cache")?;
Ok(Self { db, tree })
}
pub fn ephemeral() -> Result<Self, sled::Error> {
let db = sled::Config::new().temporary(true).open()?;
let tree = db.open_tree("llama_crab.cache")?;
Ok(Self { db, tree })
}
pub fn flush(&self) -> Result<(), sled::Error> {
let _ = self.tree.flush();
Ok(())
}
}
#[cfg(feature = "disk-cache")]
impl Cache for DiskCache {
fn lookup(&self, tokens: &[crate::token::LlamaToken]) -> Option<CacheEntry> {
let mut best_len = 0_usize;
let mut best_key: Option<Vec<u8>> = None;
for kv in self.tree.iter() {
if let Ok((k, _)) = kv {
if k.len() % std::mem::size_of::<i32>() != 0 {
continue;
}
let key_tokens: Vec<LlamaToken> = k
.chunks_exact(4)
.map(|c| i32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.map(LlamaToken)
.collect();
if key_tokens.len() > tokens.len() || key_tokens.len() <= best_len {
continue;
}
if tokens.starts_with(&key_tokens) {
best_len = key_tokens.len();
best_key = Some(k.to_vec());
}
}
}
let key = best_key?;
self.tree
.get(&key)
.ok()
.flatten()
.and_then(|v| decode_entry(&v))
}
fn store(&self, tokens: &[crate::token::LlamaToken], entry: CacheEntry) {
let key = encode_key(tokens);
let val = encode_entry(&entry);
let _ = self.tree.insert(key, val);
}
fn len(&self) -> usize {
self.tree.len()
}
}
#[cfg(feature = "disk-cache")]
fn encode_key(tokens: &[crate::token::LlamaToken]) -> Vec<u8> {
let mut out = Vec::with_capacity(tokens.len() * 4);
for t in tokens {
out.extend_from_slice(&t.0.to_le_bytes());
}
out
}
#[cfg(feature = "disk-cache")]
fn encode_entry(entry: &CacheEntry) -> Vec<u8> {
let mut out = Vec::with_capacity(8 + entry.state.len());
out.extend_from_slice(&entry.n_past.to_le_bytes());
out.extend_from_slice(&(entry.state.len() as u64).to_le_bytes());
out.extend_from_slice(&entry.state);
out
}
#[cfg(feature = "disk-cache")]
fn decode_entry(bytes: &[u8]) -> Option<CacheEntry> {
if bytes.len() < 16 {
return None;
}
let n_past = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
let state_len = u64::from_le_bytes([
bytes[4], bytes[5], bytes[6], bytes[7], bytes[8], bytes[9], bytes[10], bytes[11],
]) as usize;
if bytes.len() < 12 + state_len {
return None;
}
Some(CacheEntry {
n_past,
state: bytes[12..12 + state_len].to_vec(),
})
}
#[cfg(feature = "disk-cache")]
impl Drop for DiskCache {
fn drop(&mut self) {
let _ = self.db.flush();
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn ram_cache_longest_prefix() {
let c = RamCache::new();
let toks_a: Vec<LlamaToken> = (0..10).map(LlamaToken).collect();
let toks_b: Vec<LlamaToken> = (0..20).map(LlamaToken).collect();
c.store(&toks_a, CacheEntry { state: vec![1, 2, 3], n_past: 10 });
c.store(&toks_b, CacheEntry { state: vec![9, 9, 9], n_past: 20 });
let query: Vec<LlamaToken> = (0..20).map(LlamaToken).collect();
let hit = c.lookup(&query).unwrap();
assert_eq!(hit.n_past, 20);
}
#[test]
fn ram_cache_partial_match() {
let c = RamCache::new();
let stored: Vec<LlamaToken> = (0..10).map(LlamaToken).collect();
c.store(&stored, CacheEntry { state: vec![], n_past: 10 });
let query: Vec<LlamaToken> = (0..20).map(LlamaToken).collect();
let hit = c.lookup(&query).unwrap();
assert_eq!(hit.n_past, 10);
}
#[test]
fn ram_cache_no_match() {
let c = RamCache::new();
let stored: Vec<LlamaToken> = vec![LlamaToken(0), LlamaToken(1), LlamaToken(2)];
c.store(&stored, CacheEntry { state: vec![], n_past: 3 });
let query: Vec<LlamaToken> = vec![LlamaToken(99), LlamaToken(98)];
assert!(c.lookup(&query).is_none());
}
#[cfg(feature = "disk-cache")]
#[test]
fn disk_cache_roundtrip() {
let dir = tempfile::tempdir().unwrap();
let c = DiskCache::open(dir.path().join("cache")).unwrap();
let tokens: Vec<LlamaToken> = (0..8).map(LlamaToken).collect();
c.store(
&tokens,
CacheEntry {
state: vec![1, 2, 3, 4, 5],
n_past: 8,
},
);
c.flush().unwrap();
let hit = c.lookup(&tokens).unwrap();
assert_eq!(hit.n_past, 8);
assert_eq!(hit.state, vec![1, 2, 3, 4, 5]);
}
}