use candle_core::{DType, Device, Tensor};
use super::{NeuralError, Result};
/// Dimensions and element type used to size a key/value cache.
///
/// `KvCache::new` reads `num_layers` to decide how many per-layer caches to
/// create; `KvCache::preallocate` reads the remaining fields to build the
/// shape and dtype of the zero-filled buffers.
#[derive(Clone, Debug)]
pub struct CacheConfig {
/// Extent of the sequence axis (dimension 2) of preallocated buffers.
pub max_seq_len: usize,
/// Number of per-layer caches a `KvCache` holds.
pub num_layers: usize,
/// Extent of dimension 1 of preallocated buffers.
pub num_heads: usize,
/// Extent of dimension 3 of preallocated buffers.
pub head_dim: usize,
/// Element type of preallocated buffers.
pub dtype: DType,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_seq_len: 8192,
num_layers: 12,
num_heads: 12,
head_dim: 64,
dtype: DType::F32,
}
}
}
/// Cached key and value tensors for a single layer.
#[derive(Debug)]
pub struct LayerCache {
/// Accumulated keys, or `None` while nothing has been stored.
pub key: Option<Tensor>,
/// Accumulated values, or `None` while nothing has been stored.
pub value: Option<Tensor>,
/// Extent of dimension 2 of the stored key tensor as of the last `update`;
/// 0 after construction or `clear`.
pub seq_len: usize,
}
impl LayerCache {
    /// Creates a cache with no stored tensors and a sequence length of 0.
    pub fn new() -> Self {
        Self {
            key: None,
            value: None,
            seq_len: 0,
        }
    }

    /// Appends `new_key` / `new_value` to the cache and returns the full
    /// accumulated pair.
    ///
    /// When both a key and a value are already stored, the new tensors are
    /// concatenated after them along dimension 2; otherwise the new tensors
    /// become the cache contents as-is. `seq_len` is refreshed from
    /// dimension 2 of the resulting key tensor.
    ///
    /// # Errors
    ///
    /// Propagates any error from the concatenation or the dimension lookup.
    /// The cache is left untouched in that case because the fields are only
    /// written after every fallible step has succeeded.
    pub fn update(&mut self, new_key: Tensor, new_value: Tensor) -> Result<(Tensor, Tensor)> {
        let (key, value) =
            if let (Some(prev_k), Some(prev_v)) = (self.key.as_ref(), self.value.as_ref()) {
                let joined_k = Tensor::cat(&[prev_k, &new_key], 2)?;
                let joined_v = Tensor::cat(&[prev_v, &new_value], 2)?;
                (joined_k, joined_v)
            } else {
                (new_key, new_value)
            };

        let total_len = key.dim(2)?;
        self.key = Some(key.clone());
        self.value = Some(value.clone());
        self.seq_len = total_len;
        Ok((key, value))
    }

    /// Drops the stored tensors and resets the sequence length to 0.
    pub fn clear(&mut self) {
        self.seq_len = 0;
        self.value = None;
        self.key = None;
    }

    /// Reports whether no key tensor is stored (the value slot is not
    /// inspected).
    pub fn is_empty(&self) -> bool {
        !self.key.is_some()
    }
}
impl Default for LayerCache {
fn default() -> Self {
Self::new()
}
}
/// Key/value cache holding one `LayerCache` per layer.
#[derive(Debug)]
pub struct KvCache {
// One entry per layer, created from `config.num_layers` in `new`.
layers: Vec<LayerCache>,
// Sizing parameters; read by `preallocate` and exposed through `config()`.
config: CacheConfig,
// Device on which `preallocate` creates its zero-filled tensors.
device: Device,
}
impl KvCache {
    /// Builds a cache with `config.num_layers` empty per-layer slots.
    pub fn new(config: CacheConfig, device: Device) -> Self {
        let layers: Vec<LayerCache> = std::iter::repeat_with(LayerCache::new)
            .take(config.num_layers)
            .collect();
        Self {
            layers,
            config,
            device,
        }
    }

    /// Returns the cache for `layer_idx`, or `None` if the index is out of
    /// range.
    pub fn layer(&self, layer_idx: usize) -> Option<&LayerCache> {
        self.layers.as_slice().get(layer_idx)
    }

    /// Mutable counterpart of [`KvCache::layer`].
    pub fn layer_mut(&mut self, layer_idx: usize) -> Option<&mut LayerCache> {
        self.layers.as_mut_slice().get_mut(layer_idx)
    }

    /// Appends `key` / `value` to the cache of layer `layer_idx` and returns
    /// that layer's full accumulated pair.
    ///
    /// # Errors
    ///
    /// Returns `NeuralError::Inference` when `layer_idx` is out of range, and
    /// propagates any error from the layer's own `update`.
    pub fn update_layer(
        &mut self,
        layer_idx: usize,
        key: Tensor,
        value: Tensor,
    ) -> Result<(Tensor, Tensor)> {
        match self.layers.get_mut(layer_idx) {
            Some(slot) => slot.update(key, value),
            None => Err(NeuralError::Inference(format!(
                "Invalid layer index: {}",
                layer_idx
            ))),
        }
    }

    /// Sequence length recorded by the first layer, or 0 when there are no
    /// layers. Other layers are not consulted.
    pub fn seq_len(&self) -> usize {
        self.layers.first().map_or(0, |first| first.seq_len)
    }

    /// Clears every layer's cache.
    pub fn clear(&mut self) {
        self.layers.iter_mut().for_each(LayerCache::clear);
    }

    /// Reports whether every layer is empty (vacuously true with no layers).
    pub fn is_empty(&self) -> bool {
        self.layers.iter().all(LayerCache::is_empty)
    }

    /// Borrows the configuration this cache was built with.
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }

    /// Number of per-layer caches.
    pub fn num_layers(&self) -> usize {
        self.layers.len()
    }

    /// Fills every layer that has no key tensor with zero-initialised key and
    /// value tensors of shape
    /// `(batch_size, num_heads, max_seq_len, head_dim)` on this cache's
    /// device, using the configured dtype. Layers that already hold a key
    /// are skipped.
    ///
    /// NOTE(review): `seq_len` is left unchanged here, while
    /// `LayerCache::update` concatenates onto whatever tensors are stored, so
    /// an update after preallocation appends behind the zero-filled buffers.
    /// Confirm that this is the intended interaction.
    ///
    /// # Errors
    ///
    /// Propagates any tensor-allocation error; layers processed before the
    /// failure keep their new buffers.
    pub fn preallocate(&mut self, batch_size: usize) -> Result<()> {
        let dims = (
            batch_size,
            self.config.num_heads,
            self.config.max_seq_len,
            self.config.head_dim,
        );
        let dtype = self.config.dtype;
        let device = &self.device;

        for slot in self.layers.iter_mut().filter(|slot| slot.key.is_none()) {
            slot.key = Some(Tensor::zeros(dims, dtype, device)?);
            slot.value = Some(Tensor::zeros(dims, dtype, device)?);
        }
        Ok(())
    }
}
/// Key/value cache that keeps at most `window_size` trailing positions per
/// layer after each update.
#[derive(Debug)]
pub struct SlidingWindowCache {
// Underlying per-layer cache that performs the actual concatenation.
inner: KvCache,
// Maximum extent of dimension 2 retained in a layer after `update_layer`.
window_size: usize,
}
impl SlidingWindowCache {
    /// Creates an empty windowed cache over a fresh [`KvCache`].
    pub fn new(config: CacheConfig, window_size: usize, device: Device) -> Self {
        let inner = KvCache::new(config, device);
        Self { inner, window_size }
    }

    /// Appends `key` / `value` to layer `layer_idx`, then trims what is
    /// *stored* for that layer to its last `window_size` positions along
    /// dimension 2.
    ///
    /// The returned pair is the one produced by the append step, i.e. before
    /// trimming, so it can be longer than `window_size`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying update (including an invalid
    /// layer index) and from the narrowing step. If narrowing fails, the
    /// layer keeps the untrimmed tensors.
    pub fn update_layer(
        &mut self,
        layer_idx: usize,
        key: Tensor,
        value: Tensor,
    ) -> Result<(Tensor, Tensor)> {
        let appended = self.inner.update_layer(layer_idx, key, value)?;
        let window = self.window_size;

        match self.inner.layer_mut(layer_idx) {
            Some(slot) if slot.seq_len > window => {
                if let (Some(full_k), Some(full_v)) = (slot.key.as_ref(), slot.value.as_ref()) {
                    // Keep only the trailing `window` positions.
                    let offset = slot.seq_len - window;
                    let tail_k = full_k.narrow(2, offset, window)?;
                    let tail_v = full_v.narrow(2, offset, window)?;
                    slot.key = Some(tail_k);
                    slot.value = Some(tail_v);
                    slot.seq_len = window;
                }
            }
            _ => {}
        }

        Ok(appended)
    }

    /// Clears every layer of the underlying cache.
    pub fn clear(&mut self) {
        self.inner.clear();
    }

    /// Sequence length reported by the underlying cache.
    pub fn seq_len(&self) -> usize {
        self.inner.seq_len()
    }

    /// Maximum number of positions retained per layer.
    pub fn window_size(&self) -> usize {
        self.window_size
    }
}
/// Bounded cache of text embeddings with least-recently-used eviction,
/// usable through a shared reference.
#[derive(Debug)]
pub struct EmbeddingCache {
// Embeddings keyed by the 64-bit hash of the source text (see `hash_text`).
entries: dashmap::DashMap<u64, std::sync::Arc<[f32]>>,
// Capacity threshold that triggers eviction on insert.
max_entries: usize,
// Keys ordered from least- (front) to most-recently used (back); eviction
// pops from the front.
access_order: parking_lot::Mutex<std::collections::VecDeque<u64>>,
}
impl Default for EmbeddingCache {
fn default() -> Self {
Self::new(1000)
}
}
impl EmbeddingCache {
    /// Creates an empty cache that retains at most `max_entries` embeddings.
    ///
    /// Both the map and the recency queue reserve room for `max_entries`
    /// items up front.
    pub fn new(max_entries: usize) -> Self {
        Self {
            entries: dashmap::DashMap::with_capacity(max_entries),
            max_entries,
            access_order: parking_lot::Mutex::new(std::collections::VecDeque::with_capacity(
                max_entries,
            )),
        }
    }

    /// Looks up the embedding cached for `text` and marks it as the most
    /// recently used entry.
    ///
    /// Returns `None` when nothing is cached for the text's hash.
    ///
    /// Entries are keyed only by the 64-bit hash of `text`.
    /// NOTE(review): two texts whose hashes collide would share one slot;
    /// confirm the collision behaviour of `safe_hash` is acceptable here.
    pub fn get(&self, text: &str) -> Option<std::sync::Arc<[f32]>> {
        let hash = Self::hash_text(text);

        // Clone the `Arc` and let the map guard drop at the end of this
        // statement. The guard must never be held while waiting on
        // `access_order`, because `insert` and `clear` take map locks while
        // holding that mutex.
        let embedding = self
            .entries
            .get(&hash)
            .map(|entry| std::sync::Arc::clone(entry.value()))?;

        let mut order = self.access_order.lock();
        // Only refresh keys the queue still tracks. If a concurrent insert
        // evicted this key after the lookup above, re-adding it would leave
        // a stale key in the queue.
        if let Some(pos) = order.iter().position(|&h| h == hash) {
            order.remove(pos);
            order.push_back(hash);
        }
        Some(embedding)
    }

    /// Caches `embedding` for `text`, evicting least-recently-used entries
    /// as needed so the cache never holds more than `max_entries` items.
    ///
    /// Re-inserting an existing text replaces its embedding and marks it as
    /// most recently used without evicting anything. A cache created with
    /// `max_entries == 0` stores nothing.
    pub fn insert(&self, text: &str, embedding: Vec<f32>) {
        if self.max_entries == 0 {
            return;
        }
        let hash = Self::hash_text(text);
        let embedding: std::sync::Arc<[f32]> = embedding.into();

        // Every mutation of `entries` and of the queue happens under this
        // mutex, so the capacity check, the eviction and the insertion form
        // one atomic step and the queue always mirrors the map's keys.
        let mut order = self.access_order.lock();
        if let Some(pos) = order.iter().position(|&h| h == hash) {
            // Known key: it is re-queued at the back below.
            order.remove(pos);
        } else {
            while self.entries.len() >= self.max_entries {
                match order.pop_front() {
                    Some(oldest_hash) => {
                        self.entries.remove(&oldest_hash);
                    }
                    None => break,
                }
            }
        }
        self.entries.insert(hash, embedding);
        order.push_back(hash);
    }

    /// Removes every cached embedding and resets the recency queue.
    pub fn clear(&self) {
        // Hold the mutex across both clears so a concurrent `insert` cannot
        // land between them and leave the two structures out of sync.
        let mut order = self.access_order.lock();
        self.entries.clear();
        order.clear();
    }

    /// Number of embeddings currently cached.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Reports whether the cache holds no embeddings.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Hashes the UTF-8 bytes of `text` with the crate's `safe_hash` helper
    /// to obtain the cache key.
    fn hash_text(text: &str) -> u64 {
        crate::util::hash::safe_hash(text.as_bytes())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed layer cache holds nothing.
    #[test]
    fn test_layer_cache() {
        let layer = LayerCache::new();
        assert!(layer.is_empty());
    }

    /// Filling the cache and adding one more entry evicts the entry that was
    /// used least recently.
    #[test]
    fn test_embedding_cache() {
        let cache = EmbeddingCache::new(2);
        cache.insert("hello", vec![1.0, 2.0, 3.0]);
        cache.insert("world", vec![4.0, 5.0, 6.0]);
        assert_eq!(cache.len(), 2);

        // Reading "hello" makes "world" the eviction candidate.
        let fetched = cache.get("hello");
        assert_eq!(fetched.as_deref(), Some([1.0, 2.0, 3.0].as_slice()));

        cache.insert("test", vec![7.0, 8.0, 9.0]);
        assert_eq!(cache.len(), 2);
        assert!(cache.get("world").is_none());
        assert!(cache.get("hello").is_some());
        assert!(cache.get("test").is_some());
    }

    /// A read refreshes an entry's recency, so the next-oldest entry is the
    /// one evicted.
    #[test]
    fn test_embedding_cache_lru() {
        let cache = EmbeddingCache::new(3);
        for (name, value) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
            cache.insert(name, vec![value]);
        }

        let _ = cache.get("a");
        cache.insert("d", vec![4.0]);

        assert!(cache.get("a").is_some());
        assert!(cache.get("b").is_none());
        assert!(cache.get("c").is_some());
        assert!(cache.get("d").is_some());
    }

    /// Concurrent inserts from several threads complete without panicking
    /// and respect the capacity bound.
    #[test]
    fn test_embedding_cache_concurrent() {
        use std::sync::Arc;
        use std::thread;

        let shared = Arc::new(EmbeddingCache::new(100));
        let mut workers = Vec::with_capacity(4);
        for i in 0..4 {
            let cache = Arc::clone(&shared);
            workers.push(thread::spawn(move || {
                for j in 0..25 {
                    let key = format!("key_{}_{}", i, j);
                    cache.insert(&key, vec![i as f32, j as f32]);
                }
            }));
        }
        for worker in workers {
            worker.join().expect("thread panicked");
        }

        assert!(shared.len() <= 100);
        assert!(!shared.is_empty());
    }
}