Skip to main content

nexus_memory_agent/
cognitive_cache.rs

//! Cognitive cache: data models for confidence tiering and hot/cold ranking,
//! plus project "morning recall" and on-disk persistence of the cache.
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7use std::time::Duration;
8use tracing::{debug, warn};
9
10use crate::context_builder::ColdRecall;
11use crate::error::AgentError;
12use nexus_core::{EmbeddingService, Memory, ProjectIdentity};
13use nexus_storage::repository::MemoryRepository;
14use nexus_vectors::{SearchOptions, SemanticSearch, VectorEntry};
15
16/// Confidence tier for memory surfacing.
17#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
18pub enum ConfidenceTier {
19    /// Score < 0.72 - One-liner, model decides
20    Whisper,
21    /// Score >= 0.72 - Present, lightly compressed
22    Clear,
23    /// Score >= 0.85 - Full content, direct injection
24    Loud,
25}
26
27impl ConfidenceTier {
28    /// Determine confidence tier from a raw relevance score.
29    pub fn from_score(score: f32) -> Self {
30        if score >= 0.85 {
31            ConfidenceTier::Loud
32        } else if score >= 0.72 {
33            ConfidenceTier::Clear
34        } else {
35            ConfidenceTier::Whisper
36        }
37    }
38}
39
40/// Entry in the hot cognitive cache.
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct HotCacheEntry {
43    pub memory_id: i64,
44    pub content: String,
45    pub relevance_score: f32,
46    pub tier: ConfidenceTier,
47    pub promoted_at: DateTime<Utc>,
48    pub last_surfaced: DateTime<Utc>,
49    pub hot_streak: u32,
50    pub pinned: bool,
51    pub source_agent: Option<String>,
52}
53
54impl HotCacheEntry {
55    /// Calculate a composite score for eviction.
56    /// Combines relevance, hot streak (frequency), and recency (LRU).
57    pub fn eviction_score(&self) -> f32 {
58        if self.pinned {
59            return f32::MAX;
60        }
61
62        let now = Utc::now();
63        let age_secs = now
64            .signed_duration_since(self.last_surfaced)
65            .num_seconds()
66            .max(1) as f32;
67
68        // Decay factor: items not surfaced for 24h lose significant score
69        let age_days = (age_secs / 86400.0).min(80.0);
70        let recency_penalty = age_days.exp();
71
72        // Boost for repeated use (frequency)
73        let frequency_boost = (self.hot_streak as f32).ln().max(1.0);
74
75        (self.relevance_score * frequency_boost) / recency_penalty
76    }
77}
78
79/// Hot cognitive cache - holds the most active context for a project.
80#[derive(Debug, Clone, Serialize, Deserialize, Default)]
81pub struct HotCache {
82    pub entries: Vec<HotCacheEntry>,
83    pub last_updated: Option<DateTime<Utc>>,
84    pub last_session_id: Option<String>,
85}
86
87impl HotCache {
88    /// Promote a new entry to the hot cache.
89    pub fn promote(&mut self, entry: HotCacheEntry, max_entries: usize) -> bool {
90        if let Some(existing) = self
91            .entries
92            .iter_mut()
93            .find(|e| e.memory_id == entry.memory_id)
94        {
95            existing.content = entry.content;
96            existing.relevance_score = entry.relevance_score;
97            existing.tier = entry.tier;
98            existing.hot_streak += 1;
99            existing.last_surfaced = Utc::now();
100            existing.pinned = existing.pinned || entry.pinned; // preserve existing pin
101            return true;
102        }
103
104        if self.entries.len() >= max_entries {
105            // PHASE 10: LRU-Aware Eviction
106            let mut candidates: Vec<(usize, f32)> = self
107                .entries
108                .iter()
109                .enumerate()
110                .filter(|(_, e)| !e.pinned)
111                .map(|(i, e)| (i, e.eviction_score()))
112                .collect();
113
114            if !candidates.is_empty() {
115                // Sort by composite eviction score ascending (lowest score gets evicted)
116                candidates
117                    .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
118                self.entries.remove(candidates[0].0);
119            } else {
120                // All existing entries are pinned; do not exceed capacity.
121                return false;
122            }
123        }
124
125        self.entries.push(entry);
126        self.last_updated = Some(Utc::now());
127        true
128    }
129}
130
131/// Entry in the cold cache index.
132#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ColdIndexEntry {
134    pub memory_id: i64,
135    pub project_relevance: f32,
136    pub last_surfaced: Option<DateTime<Utc>>,
137}
138
139/// Cold cache index - references memories that might be relevant later.
140#[derive(Debug, Clone, Serialize, Deserialize, Default)]
141pub struct ColdCacheIndex {
142    pub entries: Vec<ColdIndexEntry>,
143    pub last_reindexed: Option<DateTime<Utc>>,
144}
145
146/// Unified cognitive cache for a project.
147#[derive(Debug, Clone, Serialize, Deserialize, Default)]
148pub struct CognitiveCache {
149    pub hot_cache: HotCache,
150    pub cold_index: ColdCacheIndex,
151}
152
153impl CognitiveCache {
154    /// Check if a memory is a system-generated memory that should be excluded from context.
155    fn is_system_memory(memory: &Memory) -> bool {
156        // Check metadata for session_lifecycle or runtime markers
157        if let Some(obj) = memory.metadata.as_object() {
158            if obj.get("session_lifecycle").is_some() || obj.get("runtime").is_some() {
159                return true;
160            }
161        }
162        // Also check labels for system tags
163        if memory
164            .labels
165            .iter()
166            .any(|l| l == "session" || l == "runtime")
167        {
168            return true;
169        }
170        false
171    }
172
173    /// Perform morning recall to surface project-relevant memories.
174    pub async fn morning_recall(
175        &self,
176        project: &ProjectIdentity,
177        namespace_id: i64,
178        memory_repo: &MemoryRepository,
179        embedding_service: Option<&dyn EmbeddingService>,
180    ) -> Vec<ColdRecall> {
181        let _start = std::time::Instant::now();
182        let query_string = format!(
183            "{} {} project context",
184            project.display_name,
185            project.git_remote.as_deref().unwrap_or("")
186        );
187        let hot_ids: std::collections::HashSet<i64> =
188            self.hot_cache.entries.iter().map(|e| e.memory_id).collect();
189
190        let mut results = Vec::new();
191
192        if let Some(service) = embedding_service {
193            match tokio::time::timeout(Duration::from_millis(2000), async {
194                if let Ok(embedding) = service.embed(&query_string).await {
195                    // Fetch recent memories for candidate matching
196                    let filters = nexus_storage::repository::ListMemoryFilters {
197                        category: None,
198                        since: None,
199                        until: None,
200                        content_like: None,
201                        include_raw: false,
202                        limit: 50,
203                        offset: 0,
204                    };
205
206                    if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
207                        let entries: Vec<VectorEntry> = memories
208                            .into_iter()
209                            .filter_map(|m| {
210                                m.content_embedding.as_ref().map(|emb| {
211                                    VectorEntry::new(
212                                        m.id,
213                                        emb.clone(),
214                                        m.category.to_string(),
215                                        namespace_id,
216                                    )
217                                })
218                            })
219                            .collect();
220
221                        let search = SemanticSearch::new();
222                        let options = SearchOptions::with_limit(20).with_threshold(0.65);
223
224                        if let Ok((search_results, _)) =
225                            search.search(&embedding, &entries, &options)
226                        {
227                            // Batch-fetch content for matches (SearchResult doesn't hold content directly)
228                            let filtered_results: Vec<_> = search_results
229                                .into_iter()
230                                .filter(|r| !hot_ids.contains(&r.id))
231                                .take(10)
232                                .collect();
233
234                            let ids: Vec<i64> = filtered_results.iter().map(|r| r.id).collect();
235
236                            let memories = match memory_repo.get_by_ids(&ids).await {
237                                Ok(m) => m,
238                                Err(e) => {
239                                    tracing::warn!("get_by_ids failed in morning_recall: {}", e);
240                                    Vec::new()
241                                }
242                            };
243
244                            // Preserve ordering from search_results by mapping id→memory
245                            let memory_by_id: HashMap<i64, Memory> =
246                                memories.into_iter().map(|m| (m.id, m)).collect();
247
248                            let mut recalls = Vec::new();
249                            for r in filtered_results {
250                                if let Some(m) = memory_by_id.get(&r.id) {
251                                    // Skip system-generated memories (session lifecycle, runtime markers)
252                                    if Self::is_system_memory(m) {
253                                        continue;
254                                    }
255                                    recalls.push(ColdRecall {
256                                        memory_id: r.id,
257                                        content: m.content.clone(),
258                                        relevance_score: r.score,
259                                        tier: ConfidenceTier::from_score(r.score),
260                                    });
261                                }
262                            }
263                            return Ok::<Vec<ColdRecall>, AgentError>(recalls);
264                        }
265                    }
266                }
267                Ok(Vec::new())
268            })
269            .await
270            {
271                Ok(Ok(recalls)) => results = recalls,
272                Ok(Err(e)) => warn!("Morning recall vector search failed: {}", e),
273                Err(_) => warn!("Morning recall vector search timed out"),
274            }
275        }
276
277        if results.is_empty() {
278            let filters = nexus_storage::repository::ListMemoryFilters {
279                category: None,
280                since: None,
281                until: None,
282                content_like: Some(&project.display_name),
283                include_raw: false,
284                limit: 10,
285                offset: 0,
286            };
287
288            if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
289                results = memories
290                    .into_iter()
291                    .filter(|m| !hot_ids.contains(&m.id) && !Self::is_system_memory(m))
292                    .take(10)
293                    .map(|m| ColdRecall {
294                        memory_id: m.id,
295                        content: m.content,
296                        relevance_score: 0.65,
297                        tier: ConfidenceTier::Whisper,
298                    })
299                    .collect();
300            }
301
302            // Also include cold_index entries if no results from fallback
303            if results.is_empty() {
304                // Sort cold index by relevance descending before taking top entries
305                let mut sorted_cold: Vec<_> = self
306                    .cold_index
307                    .entries
308                    .iter()
309                    .filter(|e| !hot_ids.contains(&e.memory_id) && e.project_relevance >= 0.3)
310                    .collect();
311                sorted_cold.sort_by(|a, b| {
312                    b.project_relevance
313                        .partial_cmp(&a.project_relevance)
314                        .unwrap_or(std::cmp::Ordering::Equal)
315                });
316                let cold_ids: Vec<i64> = sorted_cold.iter().take(10).map(|e| e.memory_id).collect();
317
318                if !cold_ids.is_empty() {
319                    match memory_repo.get_by_ids(&cold_ids).await {
320                        Ok(cold_memories) => {
321                            let cold_memory_by_id: HashMap<i64, Memory> =
322                                cold_memories.into_iter().map(|m| (m.id, m)).collect();
323
324                            for cold_entry in sorted_cold.iter().take(10) {
325                                if let Some(m) = cold_memory_by_id.get(&cold_entry.memory_id) {
326                                    // Skip system-generated memories
327                                    if Self::is_system_memory(m) {
328                                        continue;
329                                    }
330                                    results.push(ColdRecall {
331                                        memory_id: m.id,
332                                        content: m.content.clone(),
333                                        relevance_score: cold_entry.project_relevance,
334                                        tier: ConfidenceTier::from_score(
335                                            cold_entry.project_relevance,
336                                        ),
337                                    });
338                                }
339                            }
340                        }
341                        Err(e) => {
342                            debug!("get_by_ids failed for cold_index in morning_recall: {}", e);
343                        }
344                    }
345                }
346            }
347        }
348
349        debug!(
350            "Morning recall found {} items in {:?}",
351            results.len(),
352            _start.elapsed()
353        );
354        results
355    }
356
357    /// Load cognitive cache from disk or initialize if missing.
358    pub fn load_or_init(nexus_dir: &Path) -> Self {
359        let cache_dir = nexus_dir.join("cache");
360        let hot_path = cache_dir.join("hot.json");
361        let cold_path = cache_dir.join("cold_index.json");
362
363        let hot_cache = if hot_path.exists() {
364            match std::fs::read_to_string(&hot_path) {
365                Ok(s) => match serde_json::from_str(&s) {
366                    Ok(cache) => cache,
367                    Err(e) => {
368                        tracing::warn!(
369                            path = %hot_path.display(),
370                            error = %e,
371                            "Failed to parse hot cache; using defaults"
372                        );
373                        HotCache::default()
374                    }
375                },
376                Err(e) => {
377                    tracing::warn!(
378                        path = %hot_path.display(),
379                        error = %e,
380                        "Failed to read hot cache; using defaults"
381                    );
382                    HotCache::default()
383                }
384            }
385        } else {
386            HotCache::default()
387        };
388
389        let cold_index = if cold_path.exists() {
390            match std::fs::read_to_string(&cold_path) {
391                Ok(s) => match serde_json::from_str(&s) {
392                    Ok(idx) => idx,
393                    Err(e) => {
394                        tracing::warn!(
395                            path = %cold_path.display(),
396                            error = %e,
397                            "Failed to parse cold index; using defaults"
398                        );
399                        ColdCacheIndex::default()
400                    }
401                },
402                Err(e) => {
403                    tracing::warn!(
404                        path = %cold_path.display(),
405                        error = %e,
406                        "Failed to read cold index; using defaults"
407                    );
408                    ColdCacheIndex::default()
409                }
410            }
411        } else {
412            ColdCacheIndex::default()
413        };
414
415        Self {
416            hot_cache,
417            cold_index,
418        }
419    }
420
421    /// Save cognitive cache to disk atomically.
422    pub fn save(&self, nexus_dir: &Path) -> std::io::Result<()> {
423        let cache_dir = nexus_dir.join("cache");
424        std::fs::create_dir_all(&cache_dir)?;
425
426        let hot_json = serde_json::to_string_pretty(&self.hot_cache)?;
427        nexus_core::fsutil::atomic_write(&cache_dir.join("hot.json"), &hot_json)?;
428
429        let cold_json = serde_json::to_string_pretty(&self.cold_index)?;
430        nexus_core::fsutil::atomic_write(&cache_dir.join("cold_index.json"), &cold_json)?;
431
432        Ok(())
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Build a just-surfaced entry with a streak of one and no source agent.
    fn sample_entry(
        memory_id: i64,
        content: &str,
        relevance_score: f32,
        tier: ConfidenceTier,
        pinned: bool,
    ) -> HotCacheEntry {
        let now = Utc::now();
        HotCacheEntry {
            memory_id,
            content: content.to_string(),
            relevance_score,
            tier,
            promoted_at: now,
            last_surfaced: now,
            hot_streak: 1,
            pinned,
            source_agent: None,
        }
    }

    #[test]
    fn test_confidence_tier_boundaries() {
        let cases = [
            (0.85, ConfidenceTier::Loud),
            (0.84, ConfidenceTier::Clear),
            (0.72, ConfidenceTier::Clear),
            (0.71, ConfidenceTier::Whisper),
            (0.50, ConfidenceTier::Whisper),
        ];
        for &(score, expected) in &cases {
            assert_eq!(ConfidenceTier::from_score(score), expected, "score {}", score);
        }
    }

    #[test]
    fn test_hot_cache_promote_and_evict() {
        let mut hot = HotCache::default();
        let max = 2;

        hot.promote(sample_entry(1, "e1", 0.9, ConfidenceTier::Loud, false), max);
        hot.promote(sample_entry(2, "e2", 0.8, ConfidenceTier::Clear, false), max);
        assert_eq!(hot.entries.len(), 2);

        hot.promote(sample_entry(3, "e3", 0.95, ConfidenceTier::Loud, false), max);
        assert_eq!(hot.entries.len(), 2);

        // e2 has the lowest eviction score, so it is the one dropped.
        let ids: Vec<i64> = hot.entries.iter().map(|e| e.memory_id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&3));
    }

    #[test]
    fn test_hot_cache_never_evicts_pinned() {
        let mut hot = HotCache::default();
        let max = 1;

        hot.promote(sample_entry(1, "pinned", 0.1, ConfidenceTier::Whisper, true), max);
        hot.promote(sample_entry(2, "high", 0.99, ConfidenceTier::Loud, false), max);

        assert_eq!(hot.entries.len(), 1);
        assert_eq!(hot.entries[0].memory_id, 1);
    }

    #[test]
    fn test_cache_persistence_roundtrip() {
        let dir = tempdir().unwrap();
        let nexus_dir = dir.path();

        let mut cache = CognitiveCache::default();
        cache
            .hot_cache
            .entries
            .push(sample_entry(1, "test", 0.9, ConfidenceTier::Loud, false));
        cache.save(nexus_dir).unwrap();

        let loaded = CognitiveCache::load_or_init(nexus_dir);
        assert_eq!(loaded.hot_cache.entries.len(), 1);
        assert_eq!(loaded.hot_cache.entries[0].content, "test");
    }

    #[test]
    fn test_load_or_init_handles_missing_and_corrupt() {
        let dir = tempdir().unwrap();
        let nexus_dir = dir.path();

        // Nothing on disk yet: defaults.
        let fresh = CognitiveCache::load_or_init(nexus_dir);
        assert!(fresh.hot_cache.entries.is_empty());

        // Unparseable hot.json: defaults, no panic.
        let cache_dir = nexus_dir.join("cache");
        std::fs::create_dir_all(&cache_dir).unwrap();
        std::fs::write(cache_dir.join("hot.json"), "invalid json").unwrap();

        let recovered = CognitiveCache::load_or_init(nexus_dir);
        assert!(recovered.hot_cache.entries.is_empty());
    }
}