use std::time::{Duration, Instant};
use dashmap::DashMap;
use crate::prompts::chat::ChatPromptClient;
use crate::prompts::text::TextPromptClient;
/// A cached value stamped with the moment it was stored and how long it stays fresh.
struct CacheEntry<T> {
    /// The cached payload.
    value: T,
    /// When the entry was created; the entry's age is measured from this instant.
    inserted_at: Instant,
    /// How long after `inserted_at` the entry is considered fresh.
    ttl: Duration,
}

impl<T> CacheEntry<T> {
    /// Returns `true` once the entry's age has reached its TTL.
    ///
    /// The boundary is inclusive: an entry whose age equals `ttl` exactly is
    /// already expired, so a zero TTL means the entry is never fresh.
    fn is_expired(&self) -> bool {
        let age = Instant::now().duration_since(self.inserted_at);
        age >= self.ttl
    }
}
/// In-memory, TTL-based cache for text and chat prompt clients.
///
/// Text and chat prompts are stored in two separate maps, so the same key may
/// hold one entry of each kind at the same time. Nothing in this type evicts
/// expired entries on a timer. Expired entries are removed lazily, either when
/// looked up through `get_text` / `get_chat` or by `clear` / `invalidate_by_prefix`.
pub struct PromptCache {
    // Cached text prompts, keyed by the caller-supplied cache key.
    text_entries: DashMap<String, CacheEntry<TextPromptClient>>,
    // Cached chat prompts, keyed by the caller-supplied cache key.
    chat_entries: DashMap<String, CacheEntry<ChatPromptClient>>,
    // TTL stamped onto every entry at insertion time.
    default_ttl: Duration,
}
impl PromptCache {
pub fn new(default_ttl: Duration) -> Self {
Self {
text_entries: DashMap::new(),
chat_entries: DashMap::new(),
default_ttl,
}
}
pub fn get_text(&self, key: &str) -> Option<TextPromptClient> {
let entry = self.text_entries.get(key)?;
if entry.is_expired() {
drop(entry);
self.text_entries.remove(key);
return None;
}
Some(entry.value.clone())
}
pub fn put_text(&self, key: &str, prompt: TextPromptClient) {
self.text_entries.insert(
key.to_owned(),
CacheEntry {
value: prompt,
inserted_at: Instant::now(),
ttl: self.default_ttl,
},
);
}
pub fn get_text_expired(&self, key: &str) -> Option<TextPromptClient> {
self.text_entries.get(key).map(|entry| entry.value.clone())
}
pub fn get_chat(&self, key: &str) -> Option<ChatPromptClient> {
let entry = self.chat_entries.get(key)?;
if entry.is_expired() {
drop(entry);
self.chat_entries.remove(key);
return None;
}
Some(entry.value.clone())
}
pub fn put_chat(&self, key: &str, prompt: ChatPromptClient) {
self.chat_entries.insert(
key.to_owned(),
CacheEntry {
value: prompt,
inserted_at: Instant::now(),
ttl: self.default_ttl,
},
);
}
pub fn get_chat_expired(&self, key: &str) -> Option<ChatPromptClient> {
self.chat_entries.get(key).map(|entry| entry.value.clone())
}
pub fn clear(&self) {
self.text_entries.clear();
self.chat_entries.clear();
}
pub fn invalidate_by_prefix(&self, prefix: &str) {
self.text_entries.retain(|k, _| !k.starts_with(prefix));
self.chat_entries.retain(|k, _| !k.starts_with(prefix));
}
}