use crate::types::{MemoryType, SearchResult};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
/// Search filters that qualify a cached entry.
///
/// Hashed together with the query hash to form the cache key, and compared
/// for equality before a similarity-based lookup may reuse an entry — so two
/// lookups only share results when their filters match exactly.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheFilterParams {
    // Optional workspace filter; `None` means no workspace restriction.
    pub workspace: Option<String>,
    // Optional tier filter, by name.
    pub tier: Option<String>,
    // Optional restriction to specific memory types.
    pub memory_types: Option<Vec<MemoryType>>,
    // Whether archived entries were included in the original search.
    pub include_archived: bool,
    // Whether transcript entries were included in the original search.
    pub include_transcripts: bool,
    // Optional tag filter.
    pub tags: Option<Vec<String>>,
}
/// One cached search response plus its bookkeeping counters.
#[derive(Debug)]
pub struct CachedSearchResult {
    // Hash of the normalized (trimmed, lowercased) query text.
    pub query_hash: u64,
    // Query embedding, when available; enables similarity-based reuse.
    pub query_embedding: Option<Vec<f32>>,
    // Filters that were active when these results were produced.
    pub filter_params: CacheFilterParams,
    // The cached search results themselves.
    pub results: Vec<SearchResult>,
    // Creation instant; compared against a TTL to decide expiry.
    pub created_at: Instant,
    // Times this entry was served from cache (relaxed statistic).
    pub hit_count: AtomicU64,
    // Net user feedback: +1 per positive, -1 per negative.
    pub feedback_score: AtomicI64,
}
impl CachedSearchResult {
pub fn new(
query_hash: u64,
query_embedding: Option<Vec<f32>>,
filter_params: CacheFilterParams,
results: Vec<SearchResult>,
) -> Self {
Self {
query_hash,
query_embedding,
filter_params,
results,
created_at: Instant::now(),
hit_count: AtomicU64::new(0),
feedback_score: AtomicI64::new(0),
}
}
pub fn is_expired(&self, ttl: Duration) -> bool {
self.created_at.elapsed() > ttl
}
pub fn record_hit(&self) {
self.hit_count.fetch_add(1, Ordering::Relaxed);
}
pub fn record_feedback(&self, positive: bool) {
if positive {
self.feedback_score.fetch_add(1, Ordering::Relaxed);
} else {
self.feedback_score.fetch_sub(1, Ordering::Relaxed);
}
}
}
/// Tuning knobs for `SearchResultCache`.
#[derive(Debug, Clone)]
pub struct AdaptiveCacheConfig {
    // Starting cosine-similarity threshold for embedding-based reuse.
    pub similarity_threshold: f32,
    // Lower clamp when the threshold adapts downward (positive feedback).
    pub min_threshold: f32,
    // Upper clamp when the threshold adapts upward (negative feedback).
    pub max_threshold: f32,
    // Entry time-to-live, in seconds.
    pub ttl_seconds: u64,
    // Maximum entry count before insertion triggers an eviction.
    pub max_entries: usize,
    // When true, user feedback adjusts the similarity threshold over time.
    pub adaptive_enabled: bool,
}
impl Default for AdaptiveCacheConfig {
fn default() -> Self {
Self {
similarity_threshold: 0.92,
min_threshold: 0.85,
max_threshold: 0.98,
ttl_seconds: 300, max_entries: 1000,
adaptive_enabled: true,
}
}
}
/// Concurrent search-result cache with TTL expiry, capacity-based eviction,
/// and an adaptive cosine-similarity threshold for embedding-based lookups.
pub struct SearchResultCache {
    // Entries keyed by the combined hash of query text and filters.
    entries: DashMap<String, Arc<CachedSearchResult>>,
    // Static configuration: TTL, capacity, threshold bounds.
    config: AdaptiveCacheConfig,
    // Current similarity threshold stored as raw f32 bits, so it can be
    // read/updated atomically without a lock (see `current_threshold()`).
    current_threshold: std::sync::atomic::AtomicU32,
    // Hit/miss/invalidation/eviction counters.
    stats: CacheStats,
}
/// Lock-free counters describing cache effectiveness.
#[derive(Debug, Default)]
pub struct CacheStats {
    // Lookups served from the cache (exact or similarity match).
    pub hits: AtomicU64,
    // Lookups that found no usable entry.
    pub misses: AtomicU64,
    // Entries removed by explicit invalidation or `clear`.
    pub invalidations: AtomicU64,
    // Entries removed to stay under the capacity limit.
    pub evictions: AtomicU64,
}
impl CacheStats {
    /// Fraction of lookups served from cache; 0.0 before any lookup happens.
    pub fn hit_rate(&self) -> f64 {
        let hits = self.hits.load(Ordering::Relaxed);
        let total = hits + self.misses.load(Ordering::Relaxed);
        if total == 0 {
            return 0.0;
        }
        hits as f64 / total as f64
    }
}
/// Serializable snapshot of cache state, suitable for reporting endpoints.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatsResponse {
    // Number of entries currently cached.
    pub entries: usize,
    // Cumulative cache hits.
    pub hits: u64,
    // Cumulative cache misses.
    pub misses: u64,
    // hits / (hits + misses); 0.0 when no lookups have occurred.
    pub hit_rate: f64,
    // Entries removed by invalidation or `clear`.
    pub invalidations: u64,
    // Entries removed by capacity eviction.
    pub evictions: u64,
    // Similarity threshold currently in effect.
    pub current_threshold: f32,
    // Configured time-to-live, in seconds.
    pub ttl_seconds: u64,
}
impl SearchResultCache {
    /// Create an empty cache; the adaptive threshold starts at
    /// `config.similarity_threshold` (encoded as raw f32 bits).
    pub fn new(config: AdaptiveCacheConfig) -> Self {
        let threshold_bits = config.similarity_threshold.to_bits();
        Self {
            entries: DashMap::new(),
            current_threshold: std::sync::atomic::AtomicU32::new(threshold_bits),
            config,
            stats: CacheStats::default(),
        }
    }

    /// The similarity threshold currently in effect (decoded from raw bits).
    pub fn current_threshold(&self) -> f32 {
        f32::from_bits(self.current_threshold.load(Ordering::Relaxed))
    }

    /// Derive the map key from the query hash plus the filter set, so the
    /// same query under different filters occupies distinct entries.
    fn cache_key(query_hash: u64, filters: &CacheFilterParams) -> String {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        query_hash.hash(&mut hasher);
        filters.hash(&mut hasher);
        format!("{:016x}", hasher.finish())
    }

    /// Hash a query case-insensitively, ignoring surrounding whitespace.
    pub fn hash_query(query: &str) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        // Trim before lowercasing so `to_lowercase` only allocates the
        // meaningful portion; the hashed text is identical to the previous
        // lowercase-then-trim order (case mapping never adds or removes
        // edge whitespace).
        query.trim().to_lowercase().hash(&mut hasher);
        hasher.finish()
    }

    /// Cosine similarity of two vectors. Returns 0.0 for mismatched lengths,
    /// empty input, or a zero-norm vector (avoiding a division by zero).
    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() || a.is_empty() {
            return 0.0;
        }
        let mut dot = 0.0f32;
        let mut norm_a = 0.0f32;
        let mut norm_b = 0.0f32;
        for (x, y) in a.iter().zip(b.iter()) {
            dot += x * y;
            norm_a += x * x;
            norm_b += y * y;
        }
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        dot / (norm_a.sqrt() * norm_b.sqrt())
    }

    /// Look up cached results: first by exact key, then — when an embedding
    /// is supplied — by an O(n) scan for a semantically similar query cached
    /// under the same filters. Records hit/miss statistics either way.
    pub fn get(
        &self,
        query: &str,
        query_embedding: Option<&[f32]>,
        filters: &CacheFilterParams,
    ) -> Option<Vec<SearchResult>> {
        let query_hash = Self::hash_query(query);
        let cache_key = Self::cache_key(query_hash, filters);
        if let Some(entry) = self.entries.get(&cache_key) {
            if !entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
                entry.record_hit();
                self.stats.hits.fetch_add(1, Ordering::Relaxed);
                return Some(entry.results.clone());
            } else {
                // Release the read guard before removing: DashMap `remove`
                // takes a write lock on the shard our guard is holding.
                drop(entry);
                self.entries.remove(&cache_key);
            }
        }
        // Similarity fallback: linear scan over entries with matching
        // filters. Expired entries are skipped (left for the reaper).
        if let Some(embedding) = query_embedding {
            let threshold = self.current_threshold();
            for entry in self.entries.iter() {
                if entry.filter_params != *filters {
                    continue;
                }
                if entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
                    continue;
                }
                if let Some(ref cached_embedding) = entry.query_embedding {
                    let similarity = Self::cosine_similarity(embedding, cached_embedding);
                    if similarity >= threshold {
                        entry.record_hit();
                        self.stats.hits.fetch_add(1, Ordering::Relaxed);
                        return Some(entry.results.clone());
                    }
                }
            }
        }
        self.stats.misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Insert (or overwrite) the results for a query/filter pair, evicting
    /// the oldest entry first when a *new* key would exceed capacity.
    pub fn put(
        &self,
        query: &str,
        query_embedding: Option<Vec<f32>>,
        filters: CacheFilterParams,
        results: Vec<SearchResult>,
    ) {
        let query_hash = Self::hash_query(query);
        let cache_key = Self::cache_key(query_hash, &filters);
        // Only make room when the key is genuinely new: overwriting an
        // existing entry does not grow the map, and evicting for it would
        // needlessly shrink the cache below capacity.
        if self.entries.len() >= self.config.max_entries && !self.entries.contains_key(&cache_key) {
            self.evict_oldest();
        }
        let entry = CachedSearchResult::new(query_hash, query_embedding, filters, results);
        self.entries.insert(cache_key, Arc::new(entry));
    }

    /// Remove the single oldest entry (linear scan) and count an eviction.
    fn evict_oldest(&self) {
        let mut oldest_key: Option<String> = None;
        // All existing entries were created before "now", so the first
        // comparison always seeds the scan.
        let mut oldest_time = Instant::now();
        for entry in self.entries.iter() {
            if entry.created_at < oldest_time {
                oldest_time = entry.created_at;
                oldest_key = Some(entry.key().clone());
            }
        }
        if let Some(key) = oldest_key {
            self.entries.remove(&key);
            self.stats.evictions.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Drop all entries whose TTL has elapsed.
    pub fn remove_expired(&self) {
        let ttl = Duration::from_secs(self.config.ttl_seconds);
        self.entries.retain(|_, v| !v.is_expired(ttl));
    }

    /// Drop every entry cached under exactly this workspace filter
    /// (`None` invalidates entries cached with no workspace filter).
    pub fn invalidate_for_workspace(&self, workspace: Option<&str>) {
        self.entries.retain(|_, v| {
            let should_keep = v.filter_params.workspace.as_deref() != workspace;
            if !should_keep {
                self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
            }
            should_keep
        });
    }

    /// Drop every entry whose results contain the given memory id (e.g.
    /// after that memory was updated or deleted).
    pub fn invalidate_for_memory(&self, memory_id: i64) {
        self.entries.retain(|_, v| {
            let contains_memory = v.results.iter().any(|r| r.memory.id == memory_id);
            if contains_memory {
                self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
            }
            !contains_memory
        });
    }

    /// Drop everything, counting each removed entry as an invalidation.
    pub fn clear(&self) {
        let count = self.entries.len();
        self.entries.clear();
        self.stats
            .invalidations
            .fetch_add(count as u64, Ordering::Relaxed);
    }

    /// Record user feedback on the exact-match entry for this query (if one
    /// exists) and, when adaptation is enabled, nudge the global threshold.
    pub fn record_feedback(&self, query: &str, filters: &CacheFilterParams, positive: bool) {
        let query_hash = Self::hash_query(query);
        let cache_key = Self::cache_key(query_hash, filters);
        if let Some(entry) = self.entries.get(&cache_key) {
            entry.record_feedback(positive);
        }
        if self.config.adaptive_enabled {
            self.adjust_threshold(positive);
        }
    }

    /// Move the similarity threshold one step: looser after positive
    /// feedback (similar matches were good), stricter after negative,
    /// clamped to `[min_threshold, max_threshold]`.
    ///
    /// FIX: uses a compare-and-swap loop (`fetch_update`) instead of the
    /// previous separate load + store, which could lose adjustments when
    /// two feedback calls raced between the read and the write.
    fn adjust_threshold(&self, positive: bool) {
        const STEP: f32 = 0.01;
        let _ = self
            .current_threshold
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |bits| {
                let current = f32::from_bits(bits);
                let next = if positive {
                    (current - STEP).max(self.config.min_threshold)
                } else {
                    (current + STEP).min(self.config.max_threshold)
                };
                Some(next.to_bits())
            });
    }

    /// Snapshot counters and configuration for reporting.
    pub fn stats(&self) -> CacheStatsResponse {
        CacheStatsResponse {
            entries: self.entries.len(),
            hits: self.stats.hits.load(Ordering::Relaxed),
            misses: self.stats.misses.load(Ordering::Relaxed),
            hit_rate: self.stats.hit_rate(),
            invalidations: self.stats.invalidations.load(Ordering::Relaxed),
            evictions: self.stats.evictions.load(Ordering::Relaxed),
            current_threshold: self.current_threshold(),
            ttl_seconds: self.config.ttl_seconds,
        }
    }

    /// Spawn a detached background thread that purges expired entries every
    /// `interval_secs`. The closure owns the `Arc`, and the loop never
    /// exits, so the cache stays alive for the life of the process.
    pub fn start_expiration_worker(cache: Arc<Self>, interval_secs: u64) {
        std::thread::spawn(move || loop {
            std::thread::sleep(Duration::from_secs(interval_secs));
            cache.remove_expired();
        });
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    /// Minimal `Memory` fixture; only `id` and `content` vary per test.
    fn make_test_memory(id: i64, content: &str) -> crate::types::Memory {
        crate::types::Memory {
            id,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            importance: 0.5,
            tags: vec![],
            access_count: 0,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Default::default(),
            version: 1,
            has_embedding: false,
            metadata: Default::default(),
            scope: crate::types::MemoryScope::Global,
            workspace: "default".to_string(),
            tier: crate::types::MemoryTier::Permanent,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
            media_url: None,
        }
    }

    /// Wrap a fixture memory in a `SearchResult` with the given score.
    fn make_test_result(id: i64, content: &str, score: f32) -> SearchResult {
        let match_info = crate::types::MatchInfo {
            strategy: crate::types::SearchStrategy::Hybrid,
            matched_terms: vec![],
            highlights: vec![],
            semantic_score: None,
            keyword_score: Some(score),
        };
        SearchResult {
            memory: make_test_memory(id, content),
            score,
            match_info,
        }
    }

    #[test]
    fn test_cache_put_get() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        cache.put(
            "test query",
            None,
            CacheFilterParams::default(),
            vec![make_test_result(1, "test content", 0.9)],
        );
        let fetched = cache
            .get("test query", None, &CacheFilterParams::default())
            .expect("entry should be retrievable after put");
        assert_eq!(fetched.len(), 1);
    }

    #[test]
    fn test_cache_miss() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let fetched = cache.get("nonexistent", None, &CacheFilterParams::default());
        assert!(fetched.is_none());
    }

    #[test]
    fn test_cache_invalidation() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        cache.put(
            "query",
            None,
            CacheFilterParams::default(),
            vec![make_test_result(1, "test", 0.9)],
        );
        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_some());
        // Invalidating the memory must purge every entry containing it.
        cache.invalidate_for_memory(1);
        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_none());
    }

    #[test]
    fn test_different_filters_different_cache() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let ws_filters = |name: &str| CacheFilterParams {
            workspace: Some(name.to_string()),
            ..Default::default()
        };
        let filters1 = ws_filters("ws1");
        let filters2 = ws_filters("ws2");
        cache.put(
            "query",
            None,
            filters1.clone(),
            vec![make_test_result(1, "result 1", 0.9)],
        );
        cache.put(
            "query",
            None,
            filters2.clone(),
            vec![make_test_result(2, "result 2", 0.8)],
        );
        let hit1 = cache.get("query", None, &filters1).expect("ws1 entry cached");
        let hit2 = cache.get("query", None, &filters2).expect("ws2 entry cached");
        assert_eq!(hit1[0].memory.id, 1);
        assert_eq!(hit2[0].memory.id, 2);
    }

    #[test]
    fn test_similarity_lookup() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig {
            similarity_threshold: 0.9,
            ..Default::default()
        });
        let embedding = vec![1.0, 0.0, 0.0];
        cache.put(
            "original query",
            Some(embedding.clone()),
            CacheFilterParams::default(),
            vec![make_test_result(1, "test", 0.9)],
        );
        // Identical embedding under different text: served via similarity.
        assert!(cache
            .get(
                "different query",
                Some(&embedding),
                &CacheFilterParams::default(),
            )
            .is_some());
        // Nearly-parallel vector: still above the 0.9 threshold.
        let similar = vec![0.99, 0.1, 0.0];
        assert!(cache
            .get("another query", Some(&similar), &CacheFilterParams::default())
            .is_some());
        // Orthogonal vector: cosine similarity 0, must miss.
        let different = vec![0.0, 1.0, 0.0];
        assert!(cache
            .get("yet another", Some(&different), &CacheFilterParams::default())
            .is_none());
    }

    #[test]
    fn test_stats() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        // One miss before the entry exists...
        cache.get("query", None, &CacheFilterParams::default());
        cache.put(
            "query",
            None,
            CacheFilterParams::default(),
            vec![make_test_result(1, "test", 0.9)],
        );
        // ...then two hits.
        cache.get("query", None, &CacheFilterParams::default());
        cache.get("query", None, &CacheFilterParams::default());
        let stats = cache.stats();
        assert_eq!(stats.entries, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 2);
        assert!(stats.hit_rate > 0.6);
    }
}