Skip to main content

nexus_memory_agent/
cognitive_cache.rs

1//! Cognitive cache data models for tiering and ranking.
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7use std::time::Duration;
8use tracing::{debug, warn};
9
10use crate::context_builder::ColdRecall;
11use crate::error::AgentError;
12use nexus_core::{EmbeddingService, Memory, ProjectIdentity};
13use nexus_storage::repository::MemoryRepository;
14use nexus_vectors::{SearchOptions, SemanticSearch, VectorEntry};
15
/// Confidence tier for memory surfacing.
///
/// Variants are declared from least to most confident, so the derived
/// `Ord` yields `Whisper < Clear < Loud`. The thresholds quoted on each
/// variant are the ones applied by `ConfidenceTier::from_score`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum ConfidenceTier {
    /// Score < 0.72 - one-liner; the model decides whether to use it.
    Whisper,
    /// 0.72 <= score < 0.85 - present, lightly compressed.
    Clear,
    /// Score >= 0.85 - full content, direct injection.
    Loud,
}
26
27impl ConfidenceTier {
28    /// Determine confidence tier from a raw relevance score.
29    pub fn from_score(score: f32) -> Self {
30        if score >= 0.85 {
31            ConfidenceTier::Loud
32        } else if score >= 0.72 {
33            ConfidenceTier::Clear
34        } else {
35            ConfidenceTier::Whisper
36        }
37    }
38}
39
/// Entry in the hot cognitive cache.
///
/// One entry exists per `memory_id`; `HotCache::promote` refreshes the
/// existing entry instead of inserting a duplicate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HotCacheEntry {
    /// Identifier of the underlying memory; the cache's deduplication key.
    pub memory_id: i64,
    /// Memory text held in the cache; overwritten on re-promotion.
    pub content: String,
    /// Relevance score; the numerator of `eviction_score`.
    pub relevance_score: f32,
    /// Confidence tier stored alongside the score.
    pub tier: ConfidenceTier,
    /// NOTE(review): never written by code in this file; presumably the time
    /// the entry first entered the hot cache - confirm against callers.
    pub promoted_at: DateTime<Utc>,
    /// Reset to "now" whenever the entry is promoted again; drives the
    /// recency decay in `eviction_score`.
    pub last_surfaced: DateTime<Utc>,
    /// Incremented each time an already-cached entry is promoted again.
    pub hot_streak: u32,
    /// Pinned entries are never chosen for eviction.
    pub pinned: bool,
    /// NOTE(review): not read in this file; presumably names the agent that
    /// produced the memory - confirm.
    pub source_agent: Option<String>,
}
53
54impl HotCacheEntry {
55    /// Calculate a composite score for eviction.
56    /// Combines relevance, hot streak (frequency), and recency (LRU).
57    pub fn eviction_score(&self) -> f32 {
58        if self.pinned {
59            return f32::MAX;
60        }
61
62        let now = Utc::now();
63        let age_secs = now
64            .signed_duration_since(self.last_surfaced)
65            .num_seconds()
66            .max(1) as f32;
67
68        // Decay factor: items not surfaced for 24h lose significant score
69        let age_days = (age_secs / 86400.0).min(80.0);
70        let recency_penalty = age_days.exp();
71
72        // Boost for repeated use (frequency)
73        let frequency_boost = (self.hot_streak as f32).ln().max(1.0);
74
75        (self.relevance_score * frequency_boost) / recency_penalty
76    }
77}
78
/// Hot cognitive cache - holds the most active context for a project.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HotCache {
    /// Cached entries; `promote` keeps at most one entry per `memory_id`.
    pub entries: Vec<HotCacheEntry>,
    /// Set by `promote` when a new entry is appended; `None` for a
    /// default-constructed cache.
    pub last_updated: Option<DateTime<Utc>>,
    /// NOTE(review): neither read nor written in this file; presumably the
    /// session that last touched the cache - confirm against callers.
    pub last_session_id: Option<String>,
}
86
87impl HotCache {
88    /// Promote a new entry to the hot cache.
89    pub fn promote(&mut self, entry: HotCacheEntry, max_entries: usize) -> bool {
90        if let Some(existing) = self
91            .entries
92            .iter_mut()
93            .find(|e| e.memory_id == entry.memory_id)
94        {
95            existing.content = entry.content;
96            existing.relevance_score = entry.relevance_score;
97            existing.tier = entry.tier;
98            existing.hot_streak += 1;
99            existing.last_surfaced = Utc::now();
100            existing.pinned = existing.pinned || entry.pinned; // preserve existing pin
101            return true;
102        }
103
104        if self.entries.len() >= max_entries {
105            // PHASE 10: LRU-Aware Eviction
106            let mut candidates: Vec<(usize, f32)> = self
107                .entries
108                .iter()
109                .enumerate()
110                .filter(|(_, e)| !e.pinned)
111                .map(|(i, e)| (i, e.eviction_score()))
112                .collect();
113
114            if !candidates.is_empty() {
115                // Sort by composite eviction score ascending (lowest score gets evicted)
116                candidates
117                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
118                self.entries.remove(candidates[0].0);
119            } else {
120                // All existing entries are pinned; do not exceed capacity.
121                return false;
122            }
123        }
124
125        self.entries.push(entry);
126        self.last_updated = Some(Utc::now());
127        true
128    }
129}
130
/// Entry in the cold cache index.
///
/// Holds only a reference to a memory; the content is fetched from the
/// repository by `CognitiveCache::morning_recall` when the entry is used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColdIndexEntry {
    /// Identifier of the referenced memory.
    pub memory_id: i64,
    /// Project relevance; `morning_recall` ignores entries below 0.3, ranks
    /// the rest by this value, and reuses it as the recall's relevance score.
    pub project_relevance: f32,
    /// NOTE(review): not read or written in this file; presumably when the
    /// memory was last surfaced, `None` if never - confirm.
    pub last_surfaced: Option<DateTime<Utc>>,
}
138
/// Cold cache index - references memories that might be relevant later.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ColdCacheIndex {
    /// Indexed references, in no guaranteed order (`morning_recall` sorts a
    /// filtered copy itself).
    pub entries: Vec<ColdIndexEntry>,
    /// NOTE(review): not read or written in this file; presumably the time of
    /// the last index rebuild - confirm.
    pub last_reindexed: Option<DateTime<Utc>>,
}
145
/// Unified cognitive cache for a project.
///
/// The two halves are persisted as separate JSON files by `save` and read
/// back by `load_or_init`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CognitiveCache {
    /// Active context; stored as `cache/hot.json` under the nexus directory.
    pub hot_cache: HotCache,
    /// Cold references; stored as `cache/cold_index.json` under the nexus
    /// directory.
    pub cold_index: ColdCacheIndex,
}
152
153impl CognitiveCache {
154    /// Perform morning recall to surface project-relevant memories.
155    pub async fn morning_recall(
156        &self,
157        project: &ProjectIdentity,
158        namespace_id: i64,
159        memory_repo: &MemoryRepository,
160        embedding_service: Option<&dyn EmbeddingService>,
161    ) -> Vec<ColdRecall> {
162        let _start = std::time::Instant::now();
163        let query_string = format!(
164            "{} {} project context",
165            project.display_name,
166            project.git_remote.as_deref().unwrap_or("")
167        );
168        let hot_ids: std::collections::HashSet<i64> =
169            self.hot_cache.entries.iter().map(|e| e.memory_id).collect();
170
171        let mut results = Vec::new();
172
173        if let Some(service) = embedding_service {
174            match tokio::time::timeout(Duration::from_millis(2000), async {
175                if let Ok(embedding) = service.embed(&query_string).await {
176                    // Fetch recent memories for candidate matching
177                    let filters = nexus_storage::repository::ListMemoryFilters {
178                        category: None,
179                        since: None,
180                        until: None,
181                        content_like: None,
182                        include_raw: false,
183                        limit: 50,
184                        offset: 0,
185                    };
186
187                    if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
188                        let entries: Vec<VectorEntry> = memories
189                            .into_iter()
190                            .filter_map(|m| {
191                                m.content_embedding.as_ref().map(|emb| {
192                                    VectorEntry::new(
193                                        m.id,
194                                        emb.clone(),
195                                        m.category.to_string(),
196                                        namespace_id,
197                                    )
198                                })
199                            })
200                            .collect();
201
202                        let search = SemanticSearch::new();
203                        let options = SearchOptions::with_limit(20).with_threshold(0.65);
204
205                        if let Ok((search_results, _)) =
206                            search.search(&embedding, &entries, &options)
207                        {
208                            // Batch-fetch content for matches (SearchResult doesn't hold content directly)
209                            let filtered_results: Vec<_> = search_results
210                                .into_iter()
211                                .filter(|r| !hot_ids.contains(&r.id))
212                                .take(10)
213                                .collect();
214
215                            let ids: Vec<i64> = filtered_results.iter().map(|r| r.id).collect();
216
217                            let memories = match memory_repo.get_by_ids(&ids).await {
218                                Ok(m) => m,
219                                Err(e) => {
220                                    tracing::warn!("get_by_ids failed in morning_recall: {}", e);
221                                    Vec::new()
222                                }
223                            };
224
225                            // Preserve ordering from search_results by mapping id→memory
226                            let memory_by_id: HashMap<i64, Memory> =
227                                memories.into_iter().map(|m| (m.id, m)).collect();
228
229                            let mut recalls = Vec::new();
230                            for r in filtered_results {
231                                if let Some(m) = memory_by_id.get(&r.id) {
232                                    recalls.push(ColdRecall {
233                                        memory_id: r.id,
234                                        content: m.content.clone(),
235                                        relevance_score: r.score,
236                                        tier: ConfidenceTier::from_score(r.score),
237                                    });
238                                }
239                            }
240                            return Ok::<Vec<ColdRecall>, AgentError>(recalls);
241                        }
242                    }
243                }
244                Ok(Vec::new())
245            })
246            .await
247            {
248                Ok(Ok(recalls)) => results = recalls,
249                Ok(Err(e)) => warn!("Morning recall vector search failed: {}", e),
250                Err(_) => warn!("Morning recall vector search timed out"),
251            }
252        }
253
254        if results.is_empty() {
255            let filters = nexus_storage::repository::ListMemoryFilters {
256                category: None,
257                since: None,
258                until: None,
259                content_like: Some(&project.display_name),
260                include_raw: false,
261                limit: 10,
262                offset: 0,
263            };
264
265            if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
266                results = memories
267                    .into_iter()
268                    .filter(|m| !hot_ids.contains(&m.id))
269                    .take(10)
270                    .map(|m| ColdRecall {
271                        memory_id: m.id,
272                        content: m.content,
273                        relevance_score: 0.65,
274                        tier: ConfidenceTier::Whisper,
275                    })
276                    .collect();
277            }
278
279            // Also include cold_index entries if no results from fallback
280            if results.is_empty() {
281                // Sort cold index by relevance descending before taking top entries
282                let mut sorted_cold: Vec<_> = self
283                    .cold_index
284                    .entries
285                    .iter()
286                    .filter(|e| !hot_ids.contains(&e.memory_id) && e.project_relevance >= 0.3)
287                    .collect();
288                sorted_cold.sort_by(|a, b| {
289                    b.project_relevance
290                        .partial_cmp(&a.project_relevance)
291                        .unwrap_or(std::cmp::Ordering::Equal)
292                });
293                let cold_ids: Vec<i64> = sorted_cold.iter().take(10).map(|e| e.memory_id).collect();
294
295                if !cold_ids.is_empty() {
296                    match memory_repo.get_by_ids(&cold_ids).await {
297                        Ok(cold_memories) => {
298                            let cold_memory_by_id: HashMap<i64, Memory> =
299                                cold_memories.into_iter().map(|m| (m.id, m)).collect();
300
301                            for cold_entry in sorted_cold.iter().take(10) {
302                                if let Some(m) = cold_memory_by_id.get(&cold_entry.memory_id) {
303                                    results.push(ColdRecall {
304                                        memory_id: m.id,
305                                        content: m.content.clone(),
306                                        relevance_score: cold_entry.project_relevance,
307                                        tier: ConfidenceTier::from_score(
308                                            cold_entry.project_relevance,
309                                        ),
310                                    });
311                                }
312                            }
313                        }
314                        Err(e) => {
315                            debug!("get_by_ids failed for cold_index in morning_recall: {}", e);
316                        }
317                    }
318                }
319            }
320        }
321
322        debug!(
323            "Morning recall found {} items in {:?}",
324            results.len(),
325            _start.elapsed()
326        );
327        results
328    }
329
330    /// Load cognitive cache from disk or initialize if missing.
331    pub fn load_or_init(nexus_dir: &Path) -> Self {
332        let cache_dir = nexus_dir.join("cache");
333        let hot_path = cache_dir.join("hot.json");
334        let cold_path = cache_dir.join("cold_index.json");
335
336        let hot_cache = if hot_path.exists() {
337            match std::fs::read_to_string(&hot_path) {
338                Ok(s) => match serde_json::from_str(&s) {
339                    Ok(cache) => cache,
340                    Err(e) => {
341                        tracing::warn!(
342                            path = %hot_path.display(),
343                            error = %e,
344                            "Failed to parse hot cache; using defaults"
345                        );
346                        HotCache::default()
347                    }
348                },
349                Err(e) => {
350                    tracing::warn!(
351                        path = %hot_path.display(),
352                        error = %e,
353                        "Failed to read hot cache; using defaults"
354                    );
355                    HotCache::default()
356                }
357            }
358        } else {
359            HotCache::default()
360        };
361
362        let cold_index = if cold_path.exists() {
363            match std::fs::read_to_string(&cold_path) {
364                Ok(s) => match serde_json::from_str(&s) {
365                    Ok(idx) => idx,
366                    Err(e) => {
367                        tracing::warn!(
368                            path = %cold_path.display(),
369                            error = %e,
370                            "Failed to parse cold index; using defaults"
371                        );
372                        ColdCacheIndex::default()
373                    }
374                },
375                Err(e) => {
376                    tracing::warn!(
377                        path = %cold_path.display(),
378                        error = %e,
379                        "Failed to read cold index; using defaults"
380                    );
381                    ColdCacheIndex::default()
382                }
383            }
384        } else {
385            ColdCacheIndex::default()
386        };
387
388        Self {
389            hot_cache,
390            cold_index,
391        }
392    }
393
394    /// Save cognitive cache to disk atomically.
395    pub fn save(&self, nexus_dir: &Path) -> std::io::Result<()> {
396        let cache_dir = nexus_dir.join("cache");
397        std::fs::create_dir_all(&cache_dir)?;
398
399        let hot_json = serde_json::to_string_pretty(&self.hot_cache)?;
400        nexus_core::fsutil::atomic_write(&cache_dir.join("hot.json"), &hot_json)?;
401
402        let cold_json = serde_json::to_string_pretty(&self.cold_index)?;
403        nexus_core::fsutil::atomic_write(&cache_dir.join("cold_index.json"), &cold_json)?;
404
405        Ok(())
406    }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Fixture: an entry promoted and surfaced "now" with a hot streak of one
    /// and no source agent.
    fn entry(
        memory_id: i64,
        content: &str,
        relevance_score: f32,
        tier: ConfidenceTier,
        pinned: bool,
    ) -> HotCacheEntry {
        HotCacheEntry {
            memory_id,
            content: content.into(),
            relevance_score,
            tier,
            promoted_at: Utc::now(),
            last_surfaced: Utc::now(),
            hot_streak: 1,
            pinned,
            source_agent: None,
        }
    }

    #[test]
    fn test_confidence_tier_boundaries() {
        let cases = [
            (0.85, ConfidenceTier::Loud),
            (0.84, ConfidenceTier::Clear),
            (0.72, ConfidenceTier::Clear),
            (0.71, ConfidenceTier::Whisper),
            (0.50, ConfidenceTier::Whisper),
        ];

        for &(score, expected) in cases.iter() {
            assert_eq!(ConfidenceTier::from_score(score), expected);
        }
    }

    #[test]
    fn test_hot_cache_promote_and_evict() {
        let capacity = 2;
        let mut hot = HotCache::default();

        hot.promote(entry(1, "e1", 0.9, ConfidenceTier::Loud, false), capacity);
        hot.promote(entry(2, "e2", 0.8, ConfidenceTier::Clear, false), capacity);
        assert_eq!(hot.entries.len(), 2);

        hot.promote(entry(3, "e3", 0.95, ConfidenceTier::Loud, false), capacity);
        assert_eq!(hot.entries.len(), 2);

        // e2 has the lowest eviction score, so e1 and e3 must be the survivors.
        for &survivor in &[1_i64, 3] {
            assert!(hot.entries.iter().any(|e| e.memory_id == survivor));
        }
    }

    #[test]
    fn test_hot_cache_never_evicts_pinned() {
        let capacity = 1;
        let mut hot = HotCache::default();

        hot.promote(
            entry(1, "pinned", 0.1, ConfidenceTier::Whisper, true),
            capacity,
        );
        hot.promote(entry(2, "high", 0.99, ConfidenceTier::Loud, false), capacity);

        // The pinned low-relevance entry stays; the newcomer is rejected.
        assert_eq!(hot.entries.len(), 1);
        assert_eq!(hot.entries[0].memory_id, 1);
    }

    #[test]
    fn test_cache_persistence_roundtrip() {
        let dir = tempdir().unwrap();
        let nexus_dir = dir.path();

        let mut cache = CognitiveCache::default();
        cache
            .hot_cache
            .entries
            .push(entry(1, "test", 0.9, ConfidenceTier::Loud, false));

        cache.save(nexus_dir).unwrap();
        let loaded = CognitiveCache::load_or_init(nexus_dir);

        assert_eq!(loaded.hot_cache.entries.len(), 1);
        assert_eq!(loaded.hot_cache.entries[0].content, "test");
    }

    #[test]
    fn test_load_or_init_handles_missing_and_corrupt() {
        let dir = tempdir().unwrap();
        let nexus_dir = dir.path();

        // Missing: no cache directory exists yet.
        let from_missing = CognitiveCache::load_or_init(nexus_dir);
        assert_eq!(from_missing.hot_cache.entries.len(), 0);

        // Corrupt: hot.json exists but is not valid JSON.
        let cache_dir = nexus_dir.join("cache");
        std::fs::create_dir_all(&cache_dir).unwrap();
        std::fs::write(cache_dir.join("hot.json"), "invalid json").unwrap();

        let from_corrupt = CognitiveCache::load_or_init(nexus_dir);
        assert_eq!(from_corrupt.hot_cache.entries.len(), 0);
    }
}