1use crate::types::{MemoryType, SearchResult};
10use dashmap::DashMap;
11use serde::{Deserialize, Serialize};
12use std::hash::{Hash, Hasher};
13use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::{Duration, Instant};
16
/// Search filter parameters that participate in cache-key derivation.
///
/// Two lookups can share a cache entry only when these fields compare equal:
/// the derived `Hash`/`Eq` feed `SearchResultCache::cache_key`, and the
/// semantic-similarity fallback in `SearchResultCache::get` requires an
/// exact `==` match on this struct.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CacheFilterParams {
    /// Optional workspace filter the search was executed with.
    pub workspace: Option<String>,
    /// Optional tier filter the search was executed with.
    pub tier: Option<String>,
    /// Optional memory-type filter the search was executed with.
    pub memory_types: Option<Vec<MemoryType>>,
    /// Whether archived memories were included in the search.
    pub include_archived: bool,
    /// Whether transcript memories were included in the search.
    pub include_transcripts: bool,
    /// Optional tag filter the search was executed with.
    pub tags: Option<Vec<String>>,
}
27
/// One cached search: the query's identity, its results, and per-entry
/// counters updated lock-free while the entry is shared behind an `Arc`.
#[derive(Debug)]
pub struct CachedSearchResult {
    /// Hash of the normalized query text (see `SearchResultCache::hash_query`).
    pub query_hash: u64,
    /// Query embedding, if one was available; enables similarity-based reuse.
    pub query_embedding: Option<Vec<f32>>,
    /// Filters the search ran under; must match exactly for this entry to be reused.
    pub filter_params: CacheFilterParams,
    /// The cached search results returned verbatim (cloned) on a hit.
    pub results: Vec<SearchResult>,
    /// Creation timestamp; drives TTL-based expiration.
    pub created_at: Instant,
    /// Number of lookups served by this entry.
    pub hit_count: AtomicU64,
    /// Running feedback tally: +1 per positive, -1 per negative signal.
    pub feedback_score: AtomicI64,
}
46
47impl CachedSearchResult {
48 pub fn new(
49 query_hash: u64,
50 query_embedding: Option<Vec<f32>>,
51 filter_params: CacheFilterParams,
52 results: Vec<SearchResult>,
53 ) -> Self {
54 Self {
55 query_hash,
56 query_embedding,
57 filter_params,
58 results,
59 created_at: Instant::now(),
60 hit_count: AtomicU64::new(0),
61 feedback_score: AtomicI64::new(0),
62 }
63 }
64
65 pub fn is_expired(&self, ttl: Duration) -> bool {
67 self.created_at.elapsed() > ttl
68 }
69
70 pub fn record_hit(&self) {
72 self.hit_count.fetch_add(1, Ordering::Relaxed);
73 }
74
75 pub fn record_feedback(&self, positive: bool) {
77 if positive {
78 self.feedback_score.fetch_add(1, Ordering::Relaxed);
79 } else {
80 self.feedback_score.fetch_sub(1, Ordering::Relaxed);
81 }
82 }
83}
84
/// Tunables for `SearchResultCache`, including the bounds within which the
/// feedback-driven similarity threshold is allowed to move.
#[derive(Debug, Clone)]
pub struct AdaptiveCacheConfig {
    /// Starting cosine-similarity threshold for semantic cache hits.
    pub similarity_threshold: f32,
    /// Lowest value the adaptive threshold may drop to.
    pub min_threshold: f32,
    /// Highest value the adaptive threshold may rise to.
    pub max_threshold: f32,
    /// Entry time-to-live, in seconds.
    pub ttl_seconds: u64,
    /// Entry count at which `put` evicts the oldest entry.
    pub max_entries: usize,
    /// Whether recorded feedback adjusts the threshold at runtime.
    pub adaptive_enabled: bool,
}
101
102impl Default for AdaptiveCacheConfig {
103 fn default() -> Self {
104 Self {
105 similarity_threshold: 0.92,
106 min_threshold: 0.85,
107 max_threshold: 0.98,
108 ttl_seconds: 300, max_entries: 1000,
110 adaptive_enabled: true,
111 }
112 }
113}
114
/// Thread-safe search-result cache with TTL expiration, oldest-first
/// eviction, and a feedback-driven adaptive similarity threshold.
pub struct SearchResultCache {
    /// Entries keyed by `cache_key(query_hash, filters)`.
    entries: DashMap<String, Arc<CachedSearchResult>>,
    /// Configuration fixed at construction time.
    config: AdaptiveCacheConfig,
    /// Current similarity threshold, stored as the raw bits of an `f32`
    /// (std has no `AtomicF32`; see `current_threshold()` for decoding).
    current_threshold: std::sync::atomic::AtomicU32,
    /// Hit/miss/invalidation/eviction counters.
    stats: CacheStats,
}
126
/// Lock-free lifetime counters for the cache; all updated with relaxed
/// atomics, so cross-counter reads are only approximately consistent.
#[derive(Debug, Default)]
pub struct CacheStats {
    /// Lookups served from cache (exact-key or similarity match).
    pub hits: AtomicU64,
    /// Lookups that found no usable entry.
    pub misses: AtomicU64,
    /// Entries dropped by invalidation (`invalidate_*`, `clear`).
    pub invalidations: AtomicU64,
    /// Entries dropped to make room under `max_entries`.
    pub evictions: AtomicU64,
}
135
136impl CacheStats {
137 pub fn hit_rate(&self) -> f64 {
138 let hits = self.hits.load(Ordering::Relaxed);
139 let misses = self.misses.load(Ordering::Relaxed);
140 let total = hits + misses;
141 if total == 0 {
142 0.0
143 } else {
144 hits as f64 / total as f64
145 }
146 }
147}
148
/// Serializable point-in-time snapshot of cache state, produced by
/// `SearchResultCache::stats` for reporting/diagnostics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatsResponse {
    /// Number of entries currently cached.
    pub entries: usize,
    /// Total cache hits so far.
    pub hits: u64,
    /// Total cache misses so far.
    pub misses: u64,
    /// hits / (hits + misses); 0.0 when no lookups yet.
    pub hit_rate: f64,
    /// Total entries removed by invalidation.
    pub invalidations: u64,
    /// Total entries removed by capacity eviction.
    pub evictions: u64,
    /// Similarity threshold in effect when the snapshot was taken.
    pub current_threshold: f32,
    /// Configured entry TTL, in seconds.
    pub ttl_seconds: u64,
}
161
162impl SearchResultCache {
163 pub fn new(config: AdaptiveCacheConfig) -> Self {
164 let threshold_bits = config.similarity_threshold.to_bits();
165 Self {
166 entries: DashMap::new(),
167 current_threshold: std::sync::atomic::AtomicU32::new(threshold_bits),
168 config,
169 stats: CacheStats::default(),
170 }
171 }
172
173 pub fn current_threshold(&self) -> f32 {
175 f32::from_bits(self.current_threshold.load(Ordering::Relaxed))
176 }
177
178 fn cache_key(query_hash: u64, filters: &CacheFilterParams) -> String {
180 let mut hasher = std::collections::hash_map::DefaultHasher::new();
181 query_hash.hash(&mut hasher);
182 filters.hash(&mut hasher);
183 format!("{:016x}", hasher.finish())
184 }
185
186 pub fn hash_query(query: &str) -> u64 {
188 let mut hasher = std::collections::hash_map::DefaultHasher::new();
189 query.to_lowercase().trim().hash(&mut hasher);
190 hasher.finish()
191 }
192
193 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
195 if a.len() != b.len() || a.is_empty() {
196 return 0.0;
197 }
198
199 let mut dot = 0.0f32;
200 let mut norm_a = 0.0f32;
201 let mut norm_b = 0.0f32;
202
203 for (x, y) in a.iter().zip(b.iter()) {
204 dot += x * y;
205 norm_a += x * x;
206 norm_b += y * y;
207 }
208
209 if norm_a == 0.0 || norm_b == 0.0 {
210 return 0.0;
211 }
212
213 dot / (norm_a.sqrt() * norm_b.sqrt())
214 }
215
216 pub fn get(
218 &self,
219 query: &str,
220 query_embedding: Option<&[f32]>,
221 filters: &CacheFilterParams,
222 ) -> Option<Vec<SearchResult>> {
223 let query_hash = Self::hash_query(query);
224 let cache_key = Self::cache_key(query_hash, filters);
225
226 if let Some(entry) = self.entries.get(&cache_key) {
228 if !entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
229 entry.record_hit();
230 self.stats.hits.fetch_add(1, Ordering::Relaxed);
231 return Some(entry.results.clone());
232 } else {
233 drop(entry);
235 self.entries.remove(&cache_key);
236 }
237 }
238
239 if let Some(embedding) = query_embedding {
241 let threshold = self.current_threshold();
242
243 for entry in self.entries.iter() {
244 if entry.filter_params != *filters {
245 continue;
246 }
247
248 if entry.is_expired(Duration::from_secs(self.config.ttl_seconds)) {
249 continue;
250 }
251
252 if let Some(ref cached_embedding) = entry.query_embedding {
253 let similarity = Self::cosine_similarity(embedding, cached_embedding);
254 if similarity >= threshold {
255 entry.record_hit();
256 self.stats.hits.fetch_add(1, Ordering::Relaxed);
257 return Some(entry.results.clone());
258 }
259 }
260 }
261 }
262
263 self.stats.misses.fetch_add(1, Ordering::Relaxed);
264 None
265 }
266
267 pub fn put(
269 &self,
270 query: &str,
271 query_embedding: Option<Vec<f32>>,
272 filters: CacheFilterParams,
273 results: Vec<SearchResult>,
274 ) {
275 let query_hash = Self::hash_query(query);
276 let cache_key = Self::cache_key(query_hash, &filters);
277
278 if self.entries.len() >= self.config.max_entries {
280 self.evict_oldest();
281 }
282
283 let entry = CachedSearchResult::new(query_hash, query_embedding, filters, results);
284 self.entries.insert(cache_key, Arc::new(entry));
285 }
286
287 fn evict_oldest(&self) {
289 let mut oldest_key: Option<String> = None;
290 let mut oldest_time = Instant::now();
291
292 for entry in self.entries.iter() {
293 if entry.created_at < oldest_time {
294 oldest_time = entry.created_at;
295 oldest_key = Some(entry.key().clone());
296 }
297 }
298
299 if let Some(key) = oldest_key {
300 self.entries.remove(&key);
301 self.stats.evictions.fetch_add(1, Ordering::Relaxed);
302 }
303 }
304
305 pub fn remove_expired(&self) {
307 let ttl = Duration::from_secs(self.config.ttl_seconds);
308 self.entries.retain(|_, v| !v.is_expired(ttl));
309 }
310
311 pub fn invalidate_for_workspace(&self, workspace: Option<&str>) {
313 self.entries.retain(|_, v| {
314 let should_keep = v.filter_params.workspace.as_deref() != workspace;
315 if !should_keep {
316 self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
317 }
318 should_keep
319 });
320 }
321
322 pub fn invalidate_for_memory(&self, memory_id: i64) {
324 self.entries.retain(|_, v| {
330 let contains_memory = v.results.iter().any(|r| r.memory.id == memory_id);
332 if contains_memory {
333 self.stats.invalidations.fetch_add(1, Ordering::Relaxed);
334 }
335 !contains_memory
336 });
337 }
338
339 pub fn clear(&self) {
341 let count = self.entries.len();
342 self.entries.clear();
343 self.stats
344 .invalidations
345 .fetch_add(count as u64, Ordering::Relaxed);
346 }
347
348 pub fn record_feedback(&self, query: &str, filters: &CacheFilterParams, positive: bool) {
350 let query_hash = Self::hash_query(query);
351 let cache_key = Self::cache_key(query_hash, filters);
352
353 if let Some(entry) = self.entries.get(&cache_key) {
354 entry.record_feedback(positive);
355 }
356
357 if self.config.adaptive_enabled {
359 self.adjust_threshold(positive);
360 }
361 }
362
363 fn adjust_threshold(&self, positive: bool) {
365 let current = self.current_threshold();
366 let adjustment = 0.01; let new_threshold = if positive {
369 (current - adjustment).max(self.config.min_threshold)
371 } else {
372 (current + adjustment).min(self.config.max_threshold)
374 };
375
376 self.current_threshold
377 .store(new_threshold.to_bits(), Ordering::Relaxed);
378 }
379
380 pub fn stats(&self) -> CacheStatsResponse {
382 CacheStatsResponse {
383 entries: self.entries.len(),
384 hits: self.stats.hits.load(Ordering::Relaxed),
385 misses: self.stats.misses.load(Ordering::Relaxed),
386 hit_rate: self.stats.hit_rate(),
387 invalidations: self.stats.invalidations.load(Ordering::Relaxed),
388 evictions: self.stats.evictions.load(Ordering::Relaxed),
389 current_threshold: self.current_threshold(),
390 ttl_seconds: self.config.ttl_seconds,
391 }
392 }
393
394 pub fn start_expiration_worker(cache: Arc<Self>, interval_secs: u64) {
396 std::thread::spawn(move || loop {
397 std::thread::sleep(Duration::from_secs(interval_secs));
398 cache.remove_expired();
399 });
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::MemoryType;

    // Builds a minimal `Memory` fixture: only `id` and `content` vary,
    // everything else is a neutral/default value.
    fn make_test_memory(id: i64, content: &str) -> crate::types::Memory {
        crate::types::Memory {
            id,
            content: content.to_string(),
            memory_type: MemoryType::Note,
            importance: 0.5,
            tags: vec![],
            access_count: 0,
            created_at: chrono::Utc::now(),
            updated_at: chrono::Utc::now(),
            last_accessed_at: None,
            owner_id: None,
            visibility: Default::default(),
            version: 1,
            has_embedding: false,
            metadata: Default::default(),
            scope: crate::types::MemoryScope::Global,
            workspace: "default".to_string(),
            tier: crate::types::MemoryTier::Permanent,
            expires_at: None,
            content_hash: None,
            event_time: None,
            event_duration_seconds: None,
            trigger_pattern: None,
            procedure_success_count: 0,
            procedure_failure_count: 0,
            summary_of_id: None,
            lifecycle_state: crate::types::LifecycleState::Active,
        }
    }

    // Wraps a fixture memory in a `SearchResult` with the given score.
    fn make_test_result(id: i64, content: &str, score: f32) -> SearchResult {
        SearchResult {
            memory: make_test_memory(id, content),
            score,
            match_info: crate::types::MatchInfo {
                strategy: crate::types::SearchStrategy::Hybrid,
                matched_terms: vec![],
                highlights: vec![],
                semantic_score: None,
                keyword_score: Some(score),
            },
        }
    }

    // A put followed by a get with the same query+filters returns the
    // cached results.
    #[test]
    fn test_cache_put_get() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test content", 0.9)];

        cache.put(
            "test query",
            None,
            CacheFilterParams::default(),
            results.clone(),
        );

        let cached = cache.get("test query", None, &CacheFilterParams::default());
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().len(), 1);
    }

    // A query that was never cached yields None.
    #[test]
    fn test_cache_miss() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());

        let cached = cache.get("nonexistent", None, &CacheFilterParams::default());
        assert!(cached.is_none());
    }

    // Invalidating by memory id removes every entry containing that memory.
    #[test]
    fn test_cache_invalidation() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test", 0.9)];

        cache.put("query", None, CacheFilterParams::default(), results);

        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_some());

        cache.invalidate_for_memory(1);

        assert!(cache
            .get("query", None, &CacheFilterParams::default())
            .is_none());
    }

    // The same query text under different filters occupies distinct
    // cache entries and returns its own results.
    #[test]
    fn test_different_filters_different_cache() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results1 = vec![make_test_result(1, "result 1", 0.9)];
        let results2 = vec![make_test_result(2, "result 2", 0.8)];

        let filters1 = CacheFilterParams {
            workspace: Some("ws1".to_string()),
            ..Default::default()
        };
        let filters2 = CacheFilterParams {
            workspace: Some("ws2".to_string()),
            ..Default::default()
        };

        cache.put("query", None, filters1.clone(), results1);
        cache.put("query", None, filters2.clone(), results2);

        let cached1 = cache.get("query", None, &filters1);
        let cached2 = cache.get("query", None, &filters2);

        assert!(cached1.is_some());
        assert!(cached2.is_some());
        assert_eq!(cached1.unwrap()[0].memory.id, 1);
        assert_eq!(cached2.unwrap()[0].memory.id, 2);
    }

    // Exercises the embedding-similarity fallback: identical and
    // near-identical embeddings hit; an orthogonal one misses.
    #[test]
    fn test_similarity_lookup() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig {
            similarity_threshold: 0.9,
            ..Default::default()
        });

        let embedding = vec![1.0, 0.0, 0.0];
        let results = vec![make_test_result(1, "test", 0.9)];

        cache.put(
            "original query",
            Some(embedding.clone()),
            CacheFilterParams::default(),
            results,
        );

        // Different query text, identical embedding: cosine similarity is
        // 1.0 >= 0.9, so the similarity path serves it.
        let cached = cache.get(
            "different query",
            Some(&embedding),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_some());

        // Near-identical vector: similarity ~0.9999, still above threshold.
        let similar = vec![0.99, 0.1, 0.0];
        let cached = cache.get(
            "another query",
            Some(&similar),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_some());

        // Orthogonal vector: similarity 0.0, below threshold -> miss.
        let different = vec![0.0, 1.0, 0.0];
        let cached = cache.get(
            "yet another",
            Some(&different),
            &CacheFilterParams::default(),
        );
        assert!(cached.is_none());
    }

    // One miss (before put) plus two hits (after put) should be reflected
    // in the stats snapshot.
    #[test]
    fn test_stats() {
        let cache = SearchResultCache::new(AdaptiveCacheConfig::default());
        let results = vec![make_test_result(1, "test", 0.9)];

        cache.get("query", None, &CacheFilterParams::default());

        cache.put("query", None, CacheFilterParams::default(), results);

        cache.get("query", None, &CacheFilterParams::default());
        cache.get("query", None, &CacheFilterParams::default());

        let stats = cache.stats();
        assert_eq!(stats.entries, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.hits, 2);
        assert!(stats.hit_rate > 0.6);
    }
}