1use crate::types::{MemoryType, SearchResult};
10use dashmap::DashMap;
11use serde::{Deserialize, Serialize};
12use std::hash::{Hash, Hasher};
13use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
/// Filter parameters that scope a search and therefore participate in the
/// cache key: two lookups with the same query text but different filters must
/// not share an entry. `Hash`/`Eq` are derived for exactly that purpose — the
/// hash feeds `SearchResultCache::cache_key` and equality gates the
/// embedding-similarity fallback.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheFilterParams {
    /// Restrict to a single workspace, if set.
    pub workspace: Option<String>,
    /// Tier filter, if any.
    pub tier: Option<String>,
    /// Restrict to these memory types, if set.
    pub memory_types: Option<Vec<MemoryType>>,
    /// Whether archived memories were included in the search.
    pub include_archived: bool,
    /// Whether transcript memories were included in the search.
    pub include_transcripts: bool,
    /// Restrict to these tags, if set.
    pub tags: Option<Vec<String>>,
}
27
/// A single cached search: the results plus everything needed for exact-key
/// and similarity-based lookup, TTL expiry, and feedback tracking.
#[derive(Debug)]
pub struct CachedSearchResult {
    /// Normalized hash of the query text (see `SearchResultCache::hash_query`).
    pub query_hash: u64,
    /// Embedding of the query, when one was available; enables semantic cache
    /// hits for similar-but-not-identical queries.
    pub query_embedding: Option<Vec<f32>>,
    /// Filters the results were produced under.
    pub filter_params: CacheFilterParams,
    /// The cached search results themselves.
    pub results: Vec<SearchResult>,
    /// Creation timestamp; drives TTL expiry and oldest-first eviction.
    pub created_at: Instant,
    /// Times this entry served a cache hit (atomic: entries are shared behind
    /// `Arc` across threads).
    pub hit_count: AtomicU64,
    /// Net user feedback: +1 per positive, -1 per negative report.
    pub feedback_score: AtomicI64,
}
46
47impl CachedSearchResult {
48 pub fn new(
49 query_hash: u64,
50 query_embedding: Option<Vec<f32>>,
51 filter_params: CacheFilterParams,
52 results: Vec<SearchResult>,
53 ) -> Self {
54 Self {
55 query_hash,
56 query_embedding,
57 filter_params,
58 results,
59 created_at: Instant::now(),
60 hit_count: AtomicU64::new(0),
61 feedback_score: AtomicI64::new(0),
62 }
63 }
64
65 pub fn is_expired(&self, ttl: Duration) -> bool {
67 self.created_at.elapsed() > ttl
68 }
69
70 pub fn record_hit(&self) {
72 self.hit_count.fetch_add(1, Ordering::Relaxed);
73 }
74
75 pub fn record_feedback(&self, positive: bool) {
77 if positive {
78 self.feedback_score.fetch_add(1, Ordering::Relaxed);
79 } else {
80 self.feedback_score.fetch_sub(1, Ordering::Relaxed);
81 }
82 }
83}
84
/// Tuning knobs for the search-result cache, including the starting point and
/// bounds of the adaptive similarity threshold.
#[derive(Debug, Clone)]
pub struct AdaptiveCacheConfig {
    /// Initial cosine-similarity threshold for semantic cache hits.
    pub similarity_threshold: f32,
    /// Lowest value the adaptive threshold may be relaxed to.
    pub min_threshold: f32,
    /// Highest value the adaptive threshold may be tightened to.
    pub max_threshold: f32,
    /// Entry time-to-live, in seconds.
    pub ttl_seconds: u64,
    /// Maximum number of cached entries before eviction kicks in.
    pub max_entries: usize,
    /// Whether feedback-driven threshold adjustment is enabled.
    pub adaptive_enabled: bool,
}
101
102impl Default for AdaptiveCacheConfig {
103 fn default() -> Self {
104 Self {
105 similarity_threshold: 0.92,
106 min_threshold: 0.85,
107 max_threshold: 0.98,
108 ttl_seconds: 300, max_entries: 1000,
110 adaptive_enabled: true,
111 }
112 }
113}
114
/// Concurrent search-result cache with exact-key lookup, an
/// embedding-similarity fallback, TTL expiry, and a feedback-tuned
/// similarity threshold.
pub struct SearchResultCache {
    /// Entries keyed by the combined (query hash, filters) key.
    entries: DashMap<String, Arc<CachedSearchResult>>,
    /// Static configuration (TTL, capacity, threshold bounds).
    config: AdaptiveCacheConfig,
    /// Current similarity threshold, stored as `f32::to_bits` inside an
    /// `AtomicU32` so it can be read and adjusted without a lock.
    current_threshold: std::sync::atomic::AtomicU32,
    /// Hit/miss/invalidation/eviction counters.
    stats: CacheStats,
}
126
/// Cache-effectiveness counters; every field is updated with relaxed atomic
/// operations, so the struct can be shared across threads behind `&self`.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Lookups served from the cache (exact or similarity match).
    pub hits: AtomicU64,
    /// Lookups that found nothing usable.
    pub misses: AtomicU64,
    /// Entries removed by explicit invalidation or `clear`.
    pub invalidations: AtomicU64,
    /// Entries removed to make room at capacity.
    pub evictions: AtomicU64,
}
135
136impl CacheStats {
137 pub fn hit_rate(&self) -> f64 {
138 let hits = self.hits.load(Ordering::Relaxed);
139 let misses = self.misses.load(Ordering::Relaxed);
140 let total = hits + misses;
141 if total == 0 {
142 0.0
143 } else {
144 hits as f64 / total as f64
145 }
146 }
147}
148
/// Point-in-time, serializable snapshot of cache state and counters, as
/// produced by `SearchResultCache::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatsResponse {
    /// Number of entries currently cached.
    pub entries: usize,
    /// Total cache hits recorded so far.
    pub hits: u64,
    /// Total cache misses recorded so far.
    pub misses: u64,
    /// hits / (hits + misses); 0.0 when nothing has been recorded.
    pub hit_rate: f64,
    /// Entries removed by invalidation or `clear`.
    pub invalidations: u64,
    /// Entries removed by capacity eviction.
    pub evictions: u64,
    /// The adaptive similarity threshold at snapshot time.
    pub current_threshold: f32,
    /// Configured entry time-to-live, in seconds.
    pub ttl_seconds: u64,
}
161
162impl SearchResultCache {
163 pub fn new(config: AdaptiveCacheConfig) -> Self {
164 let threshold_bits = config.similarity_threshold.to_bits();
165 Self {
166 entries: DashMap::new(),
167 current_threshold: std::sync::atomic::AtomicU32::new(threshold_bits),
168 config,
169 stats: CacheStats::default(),
170 }
171 }
172
173 pub fn current_threshold(&self) -> f32 {
175 f32::from_bits(self.current_threshold.load(Ordering::Relaxed))
176 }
177
178 fn cache_key(query_hash: u64, filters: &CacheFilterParams) -> String {
180 let mut hasher = std::collections::hash_map::DefaultHasher::new();
181 query_hash.hash(&mut hasher);
182 filters.hash(&mut hasher);
183 format!("{:016x}", hasher.finish())
184 }
185
186 pub fn hash_query(query: &str) -> u64 {
188 let mut hasher = std::collections::hash_map::DefaultHasher::new();
189 query.to_lowercase().trim().hash(&mut hasher);
190 hasher.finish()
191 }
192
193 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
195 if a.len() != b.len() || a.is_empty() {
196 return 0.0;
197 }
198
199 let mut dot = 0.0f32;
200 let mut norm_a = 0.0f32;
201 let mut norm_b = 0.0f32;
202
203 for (x, y) in a.iter().zip(b.iter()) {
204 dot += x * y;
205 norm_a += x * x;
206 norm_b += y * y;
207 }
208
209 if norm_a == 0.0 || norm_b == 0.0 {
210 return 0.0;
211 }
212
213 dot / (norm_a.sqrt() * norm_b.sqrt())
214 }
215
216 pub fn get(
218 &self,
219 query: &str,
220 query_embedding: Option<&[f32]>,
221 filters: &CacheFilterParams,
222 ) -> Option<Vec<SearchResult>> {
223 let query_hash = Self::hash_query(query);
224 let cache_key = Self::cache_key(query_hash, filters);
225
226 if let Some(entry) = self.entries.get(&cache_key) {
228 if !entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
229 entry.record_hit();
230 self.stats.hits.fetch_add(1, Ordering::Relaxed);
231 return Some(entry.results.clone());
232 } else {
233 drop(entry);
235 self.entries.remove(&cache_key);
236 }
237 }
238
239 if let Some(embedding) = query_embedding {
241 let threshold = self.current_threshold();
242
243 for entry in self.entries.iter() {
244 if entry.filter_params != *filters {
245 continue;
246 }
247
248 if entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
249 continue;
250 }
251
252 if let Some(ref cached_embedding) = entry.query_embedding {
253 let similarity = Self::cosine_similarity(embedding, cached_embedding);
254 if similarity >= threshold {
255 entry.record_hit();
256 self.stats.hits.fetch_add(1, Ordering::Relaxed);
257 return Some(entry.results.clone());
258 }
259 }
260 }
261 }
262
263 self.stats.misses.fetch_add(1, Ordering::Relaxed);
264 None
265 }
266
267 pub fn put(
269 &self,
270 query: &str,
271 query_embedding: Option<Vec<f32>>,
272 filters: CacheFilterParams,
273 results: Vec<SearchResult>,
274 ) {
275 let query_hash = Self::hash_query(query);
276 let cache_key = Self::cache_key(query_hash, &filters);
277
278 if self.entries.len() >= self.config.max_entries {
280 self.evict_oldest();
281 }
282
283 let entry = CachedSearchResult::new(query_hash, query_embedding, filters, results);
284 self.entries.insert(cache_key, Arc::new(entry));
285 }
286
287 fn evict_oldest(&self) {
289 let mut oldest_key: Option<String> = None;
290 let mut oldest_time = Instant::now();
291
292 for entry in self.entries.iter() {
293 if entry.created_at < oldest_time {
294 oldest_time = entry.created_at;
295 oldest_key = Some(entry.key().clone());
296 }
297 }
298
299 if let Some(key) = oldest_key {
300 self.entries.remove(&key);
301 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
302 }
303 }
304
305 pub fn remove_expired(&self) {
307 let ttl = Duration::from_secs(self.config.ttl_seconds);
308 self.entries.retain(|_, v| !v.is_expired(ttl));
309 }
310
311 pub fn invalidate_for_workspace(&self, workspace: Option<&str>) {
313 self.entries.retain(|_, v| {
314 let should_keep = v.filter_params.workspace.as_deref() != workspace;
315 if !should_keep {
316 self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
317 }
318 should_keep
319 });
320 }
321
322 pub fn invalidate_for_memory(&self, memory_id: i64) {
324 self.entries.retain(|_, v| {
330 let contains_memory = v.results.iter().any(|r| r.memory.id == memory_id);
332 if contains_memory {
333 self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
334 }
335 !contains_memory
336 });
337 }
338
339 pub fn clear(&self) {
341 let count = self.entries.len();
342 self.entries.clear();
343 self.stats
344 .invalidations
345 .fetch_add(count as u64, Ordering::Relaxed);
346 }
347
348 pub fn record_feedback(&self, query: &str, filters: &CacheFilterParams, positive: bool) {
350 let query_hash = Self::hash_query(query);
351 let cache_key = Self::cache_key(query_hash, filters);
352
353 if let Some(entry) = self.entries.get(&cache_key) {
354 entry.record_feedback(positive);
355 }
356
357 if self.config.adaptive_enabled {
359 self.adjust_threshold(positive);
360 }
361 }
362
363 fn adjust_threshold(&self, positive: bool) {
365 let current = self.current_threshold();
366 let adjustment = 0.01; let new_threshold = if positive {
369 (current - adjustment).max(self.config.min_threshold)
371 } else {
372 (current + adjustment).min(self.config.max_threshold)
374 };
375
376 self.current_threshold
377 .store(new_threshold.to_bits(), Ordering::Relaxed);
378 }
379
380 pub fn stats(&self) -> CacheStatsResponse {
382 CacheStatsResponse {
383 entries: self.entries.len(),
384 hits: self.stats.hits.load(Ordering::Relaxed),
385 misses: self.stats.misses.load(Ordering::Relaxed),
386 hit_rate: self.stats.hit_rate(),
387 invalidations: self.stats.invalidations.load(Ordering::Relaxed),
388 evictions: self.stats.evictions.load(Ordering::Relaxed),
389 current_threshold: self.current_threshold(),
390 ttl_seconds: self.config.ttl_seconds,
391 }
392 }
393
394 pub fn start_expiration_worker(cache: Arc<Self>, interval_secs: u64) {
396 std::thread::spawn(move || loop {
397 std::thread::sleep(Duration::from_secs(interval_secs));
398 cache.remove_expired();
399 });
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    /// Build a minimal `Memory` fixture: only `id` and `content` vary; every
    /// other field takes a neutral default so tests stay focused on caching.
    fn make_test_memory(id: i64, content: &str) -> crate::types::Memory {
        crate::types::Memory {
            id,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            importance: 0.5,
            tags: vec![],
            access_count: 0,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Default::default(),
            version: 1,
            has_embedding: false,
            metadata: Default::default(),
            scope: crate::types::MemoryScope::Global,
            workspace: "default".to_string(),
            tier: crate::types::MemoryTier::Permanent,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
            media_url: None,
        }
    }

    /// Wrap a fixture memory in a `SearchResult` carrying the given score.
    fn make_test_result(id: i64, content: &str, score: f32) -> SearchResult {
        SearchResult {
            memory: make_test_memory(id, content),
            score,
            match_info: crate::types::MatchInfo {
                strategy: crate::types::SearchStrategy::Hybrid,
                matched_terms: vec![],
                highlights: vec![],
                semantic_score: None,
                keyword_score: Some(score),
            },
        }
    }

    /// Round-trip: a stored entry is returned for an exact query + filter
    /// match.
    #[test]
    fn test_cache_put_get() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test content", 0.9)];

        cache.put(
            "test query",
            None,
            CacheFilterParams::default(),
            results.clone(),
        );

        let cached = cache.get("test query", None, &CacheFilterParams::default());
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);
    }

    /// An unknown query yields `None`.
    #[test]
    fn test_cache_miss() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());

        let cached = cache.get("nonexistent", None, &CacheFilterParams::default());
        assert!(cached.is_none());
    }

    /// Invalidating by memory id drops every entry containing that memory.
    #[test]
    fn test_cache_invalidation() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test", 0.9)];

        cache.put("query", None, CacheFilterParams::default(), results);

        // Present before invalidation...
        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_some());

        cache.invalidate_for_memory(1);

        // ...gone afterwards.
        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_none());
    }

    /// The same query text under different filters occupies distinct slots.
    #[test]
    fn test_different_filters_different_cache() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results1 = vec![make_test_result(1, "result 1", 0.9)];
        let results2 = vec![make_test_result(2, "result 2", 0.8)];

        let filters1 = CacheFilterParams {
            workspace: Some("ws1".to_string()),
            ..Default::default()
        };
        let filters2 = CacheFilterParams {
            workspace: Some("ws2".to_string()),
            ..Default::default()
        };

        cache.put("query", None, filters1.clone(), results1);
        cache.put("query", None, filters2.clone(), results2);

        let cached1 = cache.get("query", None, &filters1);
        let cached2 = cache.get("query", None, &filters2);

        assert!(cached1.is_some());
        assert!(cached2.is_some());
        // Each filter set gets back the results it stored.
        assert_eq!(cached1.unwrap()[0].memory.id, 1);
        assert_eq!(cached2.unwrap()[0].memory.id, 2);
    }

    /// With an embedding supplied, near-identical vectors hit via cosine
    /// similarity while an orthogonal vector misses.
    #[test]
    fn test_similarity_lookup() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig {
            similarity_threshold: 0.9,
            ..Default::default()
        });

        let embedding = vec![1.0, 0.0, 0.0];
        let results = vec![make_test_result(1, "test", 0.9)];

        cache.put(
            "original query",
            Some(embedding.clone()),
            CacheFilterParams::default(),
            results,
        );

        // Identical embedding: similarity 1.0, above the 0.9 threshold.
        let cached = cache.get(
            "different query",
            Some(&embedding),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_some());

        // Nearly parallel vector: similarity ~0.995, still a hit.
        let similar = vec![0.99, 0.1, 0.0];
        let cached = cache.get(
            "another query",
            Some(&similar),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_some());

        // Orthogonal vector: similarity 0.0, below threshold -> miss.
        let different = vec![0.0, 1.0, 0.0];
        let cached = cache.get(
            "yet another",
            Some(&different),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_none());
    }

    /// Counters: one miss before insertion, two hits after.
    #[test]
    fn test_stats() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test", 0.9)];

        // Miss: nothing cached yet.
        cache.get("query", None, &CacheFilterParams::default());

        cache.put("query", None, CacheFilterParams::default(), results);

        // Two exact hits.
        cache.get("query", None, &CacheFilterParams::default());
        cache.get("query", None, &CacheFilterParams::default());

        let stats = cache.stats();
        assert_eq!(stats.entries, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 2);
        assert!(stats.hit_rate > 0.6);
    }
}