use dashmap::DashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use tiktoken_rs::ChatCompletionRequestMessage;
use tiktoken_rs::num_tokens_from_messages;
/// A concurrent prompt → response cache.
///
/// Lookups first try an exact match on the prompt (found by hash and confirmed by comparing
/// the stored prompt text), then fall back to a word-overlap similarity search over all cached
/// prompts. Entries expire after `CacheConfig::ttl_seconds`, and the number of entries is
/// bounded by `CacheConfig::max_size`: expired entries are purged first, then the least
/// recently used entry is evicted.
pub struct SemanticCache {
    /// All cached entries, keyed by the hash of their prompt (see `calculate_hash`).
    store: Arc<DashMap<u64, CacheEntry>>,
    config: CacheConfig,
}

/// One cached prompt/response pair plus the bookkeeping used for expiry and eviction.
#[derive(Debug)]
struct CacheEntry {
    /// Prompt the entry was stored under. Used to confirm exact-hash hits and as the text
    /// incoming prompts are compared against during similarity search.
    prompt: String,
    /// Response returned on a hit.
    response: String,
    /// Token count of `response` as reported by `count_tokens`.
    tokens: usize,
    /// Number of stores plus served hits (starts at 1 on insert).
    hit_count: AtomicU64,
    /// Last insert/hit time in whole seconds since the Unix epoch; drives LRU eviction.
    last_access: AtomicU64,
    /// Insertion time; the TTL is measured from here.
    created_at: std::time::SystemTime,
}

impl Clone for CacheEntry {
    /// Clones the entry, snapshotting the current values of its atomic counters.
    fn clone(&self) -> Self {
        CacheEntry {
            prompt: self.prompt.clone(),
            response: self.response.clone(),
            tokens: self.tokens,
            hit_count: AtomicU64::new(self.hit_count.load(Ordering::Relaxed)),
            last_access: AtomicU64::new(self.last_access.load(Ordering::Relaxed)),
            created_at: self.created_at,
        }
    }
}

/// Tuning knobs for [`SemanticCache`].
///
/// Public because it appears in the signature of the public `SemanticCache::new`.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries kept; `0` disables caching (`put` stores nothing).
    pub max_size: usize,
    /// Minimum similarity for a non-exact prompt to count as a hit. Scores computed by the
    /// similarity search fall in `0.0..=1.0`.
    pub similarity_threshold: f32,
    /// Entry lifetime in seconds; `0` means entries never expire.
    pub ttl_seconds: u64,
}

impl Default for CacheConfig {
    /// 10 000 entries, similarity threshold 0.90, one-hour TTL.
    fn default() -> Self {
        Self {
            max_size: 10000,
            similarity_threshold: 0.90,
            ttl_seconds: 3600,
        }
    }
}

impl SemanticCache {
    /// Creates an empty cache with the given configuration.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn new(config: CacheConfig) -> Result<Self, Box<dyn std::error::Error>> {
        Ok(Self {
            store: Arc::new(DashMap::new()),
            config,
        })
    }

    /// Looks up a cached response for `prompt`.
    ///
    /// An exact match (same prompt text) is tried first. Otherwise the most similar
    /// non-expired entry is returned if it reaches the configured threshold (see
    /// `find_semantic_match`). An exact match that has expired is removed and yields `None`.
    pub async fn get(&self, prompt: &str) -> Option<String> {
        let key = self.calculate_hash(prompt);
        let mut exact_expired = false;
        if let Some(entry) = self.store.get(&key) {
            // Compare the text as well: two different prompts can share a 64-bit hash.
            if entry.prompt == prompt {
                if self.is_expired(&entry.created_at) {
                    exact_expired = true;
                } else {
                    Self::touch(&entry);
                    return Some(entry.response.clone());
                }
            }
        }
        // The shard read guard from `get` has been released at this point. Removing while it
        // was still held would deadlock, so the expired entry is dropped only now — and only if
        // it is still expired, in case a concurrent `put` refreshed it in the meantime.
        if exact_expired {
            self.store
                .remove_if(&key, |_, entry| self.is_expired(&entry.created_at));
            return None;
        }
        self.find_semantic_match(prompt).await
    }

    /// Returns the response of the most similar non-expired cached prompt, if its similarity
    /// reaches `config.similarity_threshold`.
    ///
    /// Both texts are lowercased and split on whitespace. Similarity is the number of `prompt`
    /// words (repeats included) that also occur in the cached prompt, divided by the larger
    /// of the two word counts. On ties the first entry encountered wins.
    async fn find_semantic_match(&self, prompt: &str) -> Option<String> {
        let prompt_lower = prompt.to_lowercase();
        let prompt_tokens: Vec<&str> = prompt_lower.split_whitespace().collect();
        if prompt_tokens.is_empty() {
            // Nothing to compare; without this an empty prompt would match any entry when
            // the threshold is 0.
            return None;
        }

        let mut best: Option<(u64, f32, String)> = None;
        for entry in self.store.iter() {
            if self.is_expired(&entry.created_at) {
                continue;
            }
            let cached_lower = entry.prompt.to_lowercase();
            let cached_words: Vec<&str> = cached_lower.split_whitespace().collect();
            // Set lookup keeps the overlap count linear instead of O(n·m).
            let cached_set: std::collections::HashSet<&str> =
                cached_words.iter().copied().collect();
            let overlap = prompt_tokens
                .iter()
                .filter(|token| cached_set.contains(*token))
                .count();
            // Non-zero because `prompt_tokens` is non-empty.
            let total_tokens = prompt_tokens.len().max(cached_words.len());
            let similarity = overlap as f32 / total_tokens as f32;

            let improves = match &best {
                Some((_, best_similarity, _)) => similarity > *best_similarity,
                None => true,
            };
            if similarity >= self.config.similarity_threshold && improves {
                best = Some((*entry.key(), similarity, entry.response.clone()));
            }
        }

        let (key, _, response) = best?;
        // Record the hit only after the scan, so no iterator guard is alive during this lookup.
        if let Some(entry) = self.store.get(&key) {
            Self::touch(&entry);
        }
        Some(response)
    }

    /// Stores `response` for `prompt`, replacing any previous entry for the same prompt hash.
    ///
    /// If inserting a new key would exceed `config.max_size`, space is freed first (see
    /// `evict_lru`). With `max_size == 0` nothing is stored. The size check and the insert are
    /// not atomic, so concurrent `put`s may briefly overshoot the limit.
    pub async fn put(&self, prompt: String, response: String) {
        if self.config.max_size == 0 {
            return;
        }
        let key = self.calculate_hash(&prompt);
        // Overwriting an existing key does not grow the map, so it needs no eviction.
        if !self.store.contains_key(&key) && self.store.len() >= self.config.max_size {
            self.evict_lru().await;
        }
        let tokens = self.count_tokens(&response);
        let created_at = std::time::SystemTime::now();
        let entry = CacheEntry {
            prompt,
            response,
            tokens,
            hit_count: AtomicU64::new(1),
            last_access: AtomicU64::new(Self::unix_secs(created_at)),
            created_at,
        };
        self.store.insert(key, entry);
    }

    /// Frees space by dropping every expired entry. If the cache is still at capacity
    /// afterwards, it also removes the entry with the oldest `last_access`.
    async fn evict_lru(&self) {
        self.store
            .retain(|_, entry| !self.is_expired(&entry.created_at));
        if self.store.len() < self.config.max_size {
            return;
        }
        // Copy the key out so the iterator's shard guards are dropped before `remove`;
        // removing while a `RefMulti` into the same shard is alive would deadlock.
        let victim = self
            .store
            .iter()
            .min_by_key(|entry| entry.last_access.load(Ordering::Relaxed))
            .map(|entry| *entry.key());
        if let Some(key) = victim {
            self.store.remove(&key);
        }
    }

    /// Hashes `s` with the standard library's `DefaultHasher`. Suitable only for in-process
    /// keys, since the algorithm is not guaranteed to be stable across Rust releases.
    fn calculate_hash(&self, s: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        s.hash(&mut hasher);
        hasher.finish()
    }

    /// Counts the tokens of `text` with tiktoken_rs, sent as a single `user` message for
    /// `gpt-3.5-turbo`. Falls back to roughly one token per four characters if counting fails.
    // NOTE(review): this runs on every `put`; if tiktoken_rs rebuilds its encoder per call this
    // may be costly — confirm and consider caching the encoder.
    fn count_tokens(&self, text: &str) -> usize {
        match num_tokens_from_messages(
            "gpt-3.5-turbo",
            &[ChatCompletionRequestMessage {
                role: "user".to_string(),
                content: Some(text.to_string()),
                name: None,
                function_call: None,
            }],
        ) {
            Ok(count) => count,
            Err(_) => text.chars().count() / 4,
        }
    }

    /// Returns whether an entry created at `created_at` has outlived the configured TTL.
    ///
    /// A TTL of 0 disables expiry. A `created_at` in the future (the system clock moved
    /// backwards) is treated as expired.
    fn is_expired(&self, created_at: &std::time::SystemTime) -> bool {
        if self.config.ttl_seconds == 0 {
            return false;
        }
        match created_at.elapsed() {
            Ok(duration) => duration.as_secs() > self.config.ttl_seconds,
            Err(_) => true,
        }
    }

    /// Converts `time` to whole seconds since the Unix epoch, clamping pre-epoch times to 0
    /// instead of panicking.
    fn unix_secs(time: std::time::SystemTime) -> u64 {
        time.duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Records a hit on `entry`: bumps its hit counter and refreshes its last-access time.
    fn touch(entry: &CacheEntry) {
        entry.hit_count.fetch_add(1, Ordering::Relaxed);
        entry.last_access.store(
            Self::unix_secs(std::time::SystemTime::now()),
            Ordering::Relaxed,
        );
    }
}
impl Default for SemanticCache {
    /// Builds a cache configured with `CacheConfig::default()`.
    ///
    /// # Panics
    ///
    /// Panics if `SemanticCache::new` reports an error.
    fn default() -> Self {
        let default_config = CacheConfig::default();
        Self::new(default_config).expect("Failed to create default semantic cache")
    }
}