use crate::core::models::openai::*;
use crate::storage::vector::VectorStore;
use crate::utils::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
use uuid::Uuid;
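/// A single cached completion: the prompt embedding, the stored response,
/// and the bookkeeping fields (TTL, access counters) used for validity
/// checks and eviction.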
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticCacheEntry {
pub id: String,
pub prompt_hash: String,
pub embedding: Vec<f32>,
pub response: ChatCompletionResponse,
pub model: String,
pub created_at: chrono::DateTime<chrono::Utc>,
pub last_accessed: chrono::DateTime<chrono::Utc>,
pub access_count: u64,
pub ttl_seconds: Option<u64>,
pub metadata: HashMap<String, String>,
}
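/// Tuning knobs for the cache: how similar a prompt must be to count as a
/// hit, how large the in-memory cache may grow, and which requests are
/// eligible for caching at all.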
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticCacheConfig {
pub similarity_threshold: f64,
pub max_cache_size: usize,
pub default_ttl_seconds: u64,
pub embedding_model: String,
pub enable_streaming_cache: bool,
pub min_prompt_length: usize,
pub cache_hit_boost: f64,
}
impl Default for SemanticCacheConfig {
fn default() -> Self {
Self {
similarity_threshold: 0.85,
max_cache_size: 10000,
            default_ttl_seconds: 3600,
            embedding_model: "text-embedding-ada-002".to_string(),
enable_streaming_cache: false,
min_prompt_length: 10,
cache_hit_boost: 1.1,
}
}
}
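/// Semantic response cache: rather than requiring an exact prompt match, it
/// embeds the prompt and serves a stored response whenever a previous prompt
/// is close enough in embedding space (`similarity_threshold`).
///
/// A minimal usage sketch (not compiled; assumes concrete `VectorStore` and
/// `EmbeddingProvider` implementations exist elsewhere in the crate, and
/// `call_upstream` is a hypothetical placeholder):
///
/// ```ignore
/// let cache = SemanticCache::new(config, vector_store, embedding_provider).await?;
/// if let Some(response) = cache.get_cached_response(&request).await? {
///     return Ok(response); // served from cache
/// }
/// let response = call_upstream(&request).await?;
/// cache.cache_response(&request, &response).await?;
/// ```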
pub struct SemanticCache {
config: SemanticCacheConfig,
vector_store: Arc<dyn VectorStore>,
memory_cache: Arc<RwLock<HashMap<String, SemanticCacheEntry>>>,
embedding_provider: Arc<dyn EmbeddingProvider>,
stats: Arc<RwLock<CacheStats>>,
}
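/// Aggregate cache counters; `avg_hit_similarity` is a running mean over all
/// hits.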
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct CacheStats {
pub hits: u64,
pub misses: u64,
pub total_entries: u64,
pub avg_hit_similarity: f64,
pub cache_size_bytes: u64,
}
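/// Source of prompt embeddings; implementations are expected to wrap an
/// embedding model such as the one named in
/// `SemanticCacheConfig::embedding_model`.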
#[async_trait::async_trait]
pub trait EmbeddingProvider: Send + Sync {
async fn generate_embedding(&self, text: &str) -> Result<Vec<f32>>;
fn embedding_dimension(&self) -> usize;
}
impl SemanticCache {
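    /// Builds a cache around the given vector store and embedding provider.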
pub async fn new(
config: SemanticCacheConfig,
vector_store: Arc<dyn VectorStore>,
embedding_provider: Arc<dyn EmbeddingProvider>,
) -> Result<Self> {
info!(
"Initializing semantic cache with threshold: {}",
config.similarity_threshold
);
Ok(Self {
config,
vector_store,
memory_cache: Arc::new(RwLock::new(HashMap::new())),
embedding_provider,
stats: Arc::new(RwLock::new(CacheStats::default())),
})
}
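    /// Looks up a semantically similar cached response for `request`.
    ///
    /// Returns `Ok(None)` on every non-hit path, including embedding
    /// failures, so a cache problem never fails the request itself.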
pub async fn get_cached_response(
&self,
request: &ChatCompletionRequest,
) -> Result<Option<ChatCompletionResponse>> {
if !self.should_cache_request(request) {
return Ok(None);
}
let prompt_text = self.extract_prompt_text(&request.messages);
if prompt_text.len() < self.config.min_prompt_length {
debug!("Prompt too short for caching: {} chars", prompt_text.len());
return Ok(None);
}
let embedding = match self
.embedding_provider
.generate_embedding(&prompt_text)
.await
{
Ok(emb) => emb,
Err(e) => {
warn!("Failed to generate embedding for cache lookup: {}", e);
return Ok(None);
}
};
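        // Probe the nearest neighbors and take the first sufficiently
        // similar entry that is still valid.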
let search_results = self.vector_store.search(embedding, 10).await?;
for result in search_results {
if result.score >= self.config.similarity_threshold as f32 {
if let Some(entry) = self.get_cache_entry(&result.id).await? {
if self.is_entry_valid(&entry) {
self.update_access_stats(&result.id, result.score as f64)
.await?;
let mut stats = self.stats.write().await;
stats.hits += 1;
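                        // Running mean: fold this hit's similarity into the
                        // average without storing per-hit history.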
stats.avg_hit_similarity = (stats.avg_hit_similarity
* (stats.hits - 1) as f64
+ result.score as f64)
/ stats.hits as f64;
info!(
"Cache hit! Similarity: {:.3}, Entry: {}",
result.score, result.id
);
return Ok(Some(entry.response));
} else {
self.remove_cache_entry(&result.id).await?;
}
}
}
}
        self.stats.write().await.misses += 1;
debug!(
"Cache miss for prompt: {}",
prompt_text.chars().take(100).collect::<String>()
);
Ok(None)
}
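    /// Stores `response` under an embedding of the request's prompt text.
    /// Unlike lookup, an embedding failure here is propagated to the caller.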
pub async fn cache_response(
&self,
request: &ChatCompletionRequest,
response: &ChatCompletionResponse,
) -> Result<()> {
if !self.should_cache_request(request) {
return Ok(());
}
let prompt_text = self.extract_prompt_text(&request.messages);
if prompt_text.len() < self.config.min_prompt_length {
return Ok(());
}
let embedding = self
.embedding_provider
.generate_embedding(&prompt_text)
.await?;
let entry = SemanticCacheEntry {
id: Uuid::new_v4().to_string(),
prompt_hash: self.hash_prompt(&prompt_text),
embedding: embedding.clone(),
response: response.clone(),
model: request.model.clone(),
created_at: chrono::Utc::now(),
last_accessed: chrono::Utc::now(),
access_count: 0,
ttl_seconds: Some(self.config.default_ttl_seconds),
metadata: HashMap::new(),
};
let vector_data = crate::storage::vector::VectorData {
id: entry.id.clone(),
vector: embedding,
metadata: {
let mut metadata = HashMap::new();
metadata.insert(
"prompt_hash".to_string(),
serde_json::to_value(&entry.prompt_hash)?,
);
metadata.insert(
"created_at".to_string(),
serde_json::to_value(&entry.created_at)?,
);
metadata
},
};
self.vector_store.insert(vec![vector_data]).await?;
        // Scope the lock: evict_old_entries() re-acquires the memory-cache
        // write lock and would deadlock if we still held it here.
        let cache_len = {
            let mut memory_cache = self.memory_cache.write().await;
            memory_cache.insert(entry.id.clone(), entry);
            memory_cache.len()
        };
        self.stats.write().await.total_entries += 1;
        if cache_len > self.config.max_cache_size {
            self.evict_old_entries().await?;
        }
info!("Cached response for model: {}", request.model);
Ok(())
}
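    /// Caching policy: skip streaming requests (unless explicitly enabled),
    /// requests involving tools, and high-temperature requests whose outputs
    /// are too nondeterministic to be worth reusing.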
fn should_cache_request(&self, request: &ChatCompletionRequest) -> bool {
if request.stream.unwrap_or(false) && !self.config.enable_streaming_cache {
return false;
}
if request.tools.is_some() || request.tool_choice.is_some() {
return false;
}
if let Some(temperature) = request.temperature {
if temperature > 0.7 {
return false;
}
}
true
}
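    /// Concatenates the text content of all messages, one message per line;
    /// non-text content parts are ignored.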
fn extract_prompt_text(&self, messages: &[ChatMessage]) -> String {
messages
.iter()
.filter_map(|msg| match &msg.content {
Some(MessageContent::Text(text)) => Some(text.clone()),
Some(MessageContent::Parts(parts)) => {
let text = parts
.iter()
.filter_map(|part| match part {
ContentPart::Text { text } => Some(text.clone()),
_ => None,
})
.collect::<Vec<String>>()
.join(" ");
if text.is_empty() { None } else { Some(text) }
}
None => None,
})
.collect::<Vec<String>>()
.join("\n")
}
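    /// SHA-256 hex digest of the prompt text, stored alongside the embedding
    /// for exact-match bookkeeping.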
fn hash_prompt(&self, prompt: &str) -> String {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(prompt.as_bytes());
format!("{:x}", hasher.finalize())
}
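    /// Fetches the full entry from the in-memory map. Only vector ids and
    /// light metadata live in the vector store, so an entry evicted from
    /// memory is treated as a miss even if its vector still matches.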
async fn get_cache_entry(&self, entry_id: &str) -> Result<Option<SemanticCacheEntry>> {
{
let memory_cache = self.memory_cache.read().await;
if let Some(entry) = memory_cache.get(entry_id) {
return Ok(Some(entry.clone()));
}
}
Ok(None)
}
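    /// An entry is valid until its TTL elapses; entries without a TTL never
    /// expire.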
fn is_entry_valid(&self, entry: &SemanticCacheEntry) -> bool {
if let Some(ttl_seconds) = entry.ttl_seconds {
let expiry_time = entry.created_at + chrono::Duration::seconds(ttl_seconds as i64);
chrono::Utc::now() < expiry_time
} else {
            true
        }
}
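    /// Touches `last_accessed` and bumps the entry's access count on a hit.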
async fn update_access_stats(&self, entry_id: &str, _similarity: f64) -> Result<()> {
{
let mut memory_cache = self.memory_cache.write().await;
if let Some(entry) = memory_cache.get_mut(entry_id) {
entry.last_accessed = chrono::Utc::now();
entry.access_count += 1;
}
}
Ok(())
}
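    /// Removes an entry from both the in-memory map and the vector store.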
async fn remove_cache_entry(&self, entry_id: &str) -> Result<()> {
{
let mut memory_cache = self.memory_cache.write().await;
memory_cache.remove(entry_id);
}
self.vector_store.delete(vec![entry_id.to_string()]).await?;
Ok(())
}
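    /// Evicts the least-recently-accessed ~10% of entries. Vector-store
    /// deletes run on spawned tasks so eviction never blocks on I/O.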
async fn evict_old_entries(&self) -> Result<()> {
let mut memory_cache = self.memory_cache.write().await;
let mut entries: Vec<_> = memory_cache
.iter()
.map(|(k, v)| (k.clone(), v.last_accessed))
.collect();
entries.sort_by_key(|(_, last_accessed)| *last_accessed);
let evict_count = (entries.len() as f64 * 0.1).ceil() as usize;
let entries_to_remove: Vec<String> = entries
.iter()
.take(evict_count)
.map(|(id, _)| id.clone())
.collect();
for entry_id in entries_to_remove {
memory_cache.remove(&entry_id);
let vector_store = self.vector_store.clone();
let entry_id_clone = entry_id.clone();
tokio::spawn(async move {
if let Err(e) = vector_store.delete(vec![entry_id_clone]).await {
warn!("Failed to delete entry from vector store: {}", e);
}
});
}
info!("Evicted {} old cache entries", evict_count);
Ok(())
}
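    /// Returns a snapshot of the current hit/miss statistics.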
pub async fn get_stats(&self) -> CacheStats {
self.stats.read().await.clone()
}
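    /// Clears the in-memory cache and resets statistics. Entries already
    /// persisted to the vector store are not deleted here.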
pub async fn clear_cache(&self) -> Result<()> {
{
let mut memory_cache = self.memory_cache.write().await;
memory_cache.clear();
}
{
let mut stats = self.stats.write().await;
*stats = CacheStats::default();
}
info!("Cleared all cache entries");
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::models::openai::{MessageContent, MessageRole};
#[test]
fn test_semantic_cache_config_default() {
let config = SemanticCacheConfig::default();
assert_eq!(config.similarity_threshold, 0.85);
assert_eq!(config.max_cache_size, 10000);
assert_eq!(config.default_ttl_seconds, 3600);
}
#[tokio::test]
async fn test_extract_prompt_text() {
let cache = create_test_cache().await;
let messages = vec![
ChatMessage {
role: MessageRole::System,
content: Some(MessageContent::Text(
"You are a helpful assistant".to_string(),
)),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
},
ChatMessage {
role: MessageRole::User,
content: Some(MessageContent::Text("Hello world".to_string())),
name: None,
function_call: None,
tool_calls: None,
tool_call_id: None,
audio: None,
},
];
let prompt_text = cache.extract_prompt_text(&messages);
assert!(prompt_text.contains("You are a helpful assistant"));
assert!(prompt_text.contains("Hello world"));
}
#[tokio::test]
async fn test_should_cache_request() {
let cache = create_test_cache().await;
let mut request = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![],
max_tokens: None,
max_completion_tokens: None,
temperature: Some(0.1),
top_p: None,
n: None,
stream: Some(false),
stream_options: None,
stop: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
functions: None,
function_call: None,
tools: None,
tool_choice: None,
response_format: None,
seed: None,
logprobs: None,
top_logprobs: None,
modalities: None,
audio: None,
};
assert!(cache.should_cache_request(&request));
request.temperature = Some(0.9);
assert!(!cache.should_cache_request(&request));
request.temperature = Some(0.1);
request.stream = Some(true);
assert!(!cache.should_cache_request(&request));
}
async fn create_test_cache() -> SemanticCache {
let config = SemanticCacheConfig {
similarity_threshold: 0.85,
max_cache_size: 1000,
default_ttl_seconds: 3600,
embedding_model: "text-embedding-ada-002".to_string(),
enable_streaming_cache: false,
min_prompt_length: 10,
cache_hit_boost: 1.1,
};
SemanticCache {
config,
vector_store: Arc::new(TestVectorStore),
memory_cache: Arc::new(RwLock::new(HashMap::new())),
embedding_provider: Arc::new(TestEmbeddingProvider),
stats: Arc::new(RwLock::new(CacheStats::default())),
}
}
struct TestVectorStore;
struct TestEmbeddingProvider;
#[async_trait::async_trait]
impl VectorStore for TestVectorStore {
async fn search(
&self,
_vector: Vec<f32>,
_limit: usize,
) -> Result<Vec<crate::storage::vector::SearchResult>> {
Ok(vec![])
}
async fn insert(&self, _vectors: Vec<crate::storage::vector::VectorData>) -> Result<()> {
Ok(())
}
async fn delete(&self, _ids: Vec<String>) -> Result<()> {
Ok(())
}
}
#[async_trait::async_trait]
impl EmbeddingProvider for TestEmbeddingProvider {
async fn generate_embedding(&self, _text: &str) -> Result<Vec<f32>> {
Ok(vec![0.1; 1536])
}
fn embedding_dimension(&self) -> usize {
1536
}
}
}