1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7use std::time::Duration;
8use tracing::{debug, warn};
9
10use crate::context_builder::ColdRecall;
11use crate::error::AgentError;
12use nexus_core::{EmbeddingService, Memory, ProjectIdentity};
13use nexus_storage::repository::MemoryRepository;
14use nexus_vectors::{SearchOptions, SemanticSearch, VectorEntry};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
18pub enum ConfidenceTier {
19 Whisper,
21 Clear,
23 Loud,
25}
26
27impl ConfidenceTier {
28 pub fn from_score(score: f32) -> Self {
30 if score >= 0.85 {
31 ConfidenceTier::Loud
32 } else if score >= 0.72 {
33 ConfidenceTier::Clear
34 } else {
35 ConfidenceTier::Whisper
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct HotCacheEntry {
43 pub memory_id: i64,
44 pub content: String,
45 pub relevance_score: f32,
46 pub tier: ConfidenceTier,
47 pub promoted_at: DateTime<Utc>,
48 pub last_surfaced: DateTime<Utc>,
49 pub hot_streak: u32,
50 pub pinned: bool,
51 pub source_agent: Option<String>,
52}
53
54impl HotCacheEntry {
55 pub fn eviction_score(&self) -> f32 {
58 if self.pinned {
59 return f32::MAX;
60 }
61
62 let now = Utc::now();
63 let age_secs = now
64 .signed_duration_since(self.last_surfaced)
65 .num_seconds()
66 .max(1) as f32;
67
68 let age_days = (age_secs / 86400.0).min(80.0);
70 let recency_penalty = age_days.exp();
71
72 let frequency_boost = (self.hot_streak as f32).ln().max(1.0);
74
75 (self.relevance_score * frequency_boost) / recency_penalty
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, Default)]
81pub struct HotCache {
82 pub entries: Vec<HotCacheEntry>,
83 pub last_updated: Option<DateTime<Utc>>,
84 pub last_session_id: Option<String>,
85}
86
87impl HotCache {
88 pub fn promote(&mut self, entry: HotCacheEntry, max_entries: usize) -> bool {
90 if let Some(existing) = self
91 .entries
92 .iter_mut()
93 .find(|e| e.memory_id == entry.memory_id)
94 {
95 existing.content = entry.content;
96 existing.relevance_score = entry.relevance_score;
97 existing.tier = entry.tier;
98 existing.hot_streak += 1;
99 existing.last_surfaced = Utc::now();
100 existing.pinned = existing.pinned || entry.pinned; return true;
102 }
103
104 if self.entries.len() >= max_entries {
105 let mut candidates: Vec<(usize, f32)> = self
107 .entries
108 .iter()
109 .enumerate()
110 .filter(|(_, e)| !e.pinned)
111 .map(|(i, e)| (i, e.eviction_score()))
112 .collect();
113
114 if !candidates.is_empty() {
115 candidates
117 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
118 self.entries.remove(candidates[0].0);
119 } else {
120 return false;
122 }
123 }
124
125 self.entries.push(entry);
126 self.last_updated = Some(Utc::now());
127 true
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ColdIndexEntry {
134 pub memory_id: i64,
135 pub project_relevance: f32,
136 pub last_surfaced: Option<DateTime<Utc>>,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize, Default)]
141pub struct ColdCacheIndex {
142 pub entries: Vec<ColdIndexEntry>,
143 pub last_reindexed: Option<DateTime<Utc>>,
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize, Default)]
148pub struct CognitiveCache {
149 pub hot_cache: HotCache,
150 pub cold_index: ColdCacheIndex,
151}
152
153impl CognitiveCache {
154 fn is_system_memory(memory: &Memory) -> bool {
156 if let Some(obj) = memory.metadata.as_object() {
158 if obj.get("session_lifecycle").is_some() || obj.get("runtime").is_some() {
159 return true;
160 }
161 }
162 if memory
164 .labels
165 .iter()
166 .any(|l| l == "session" || l == "runtime")
167 {
168 return true;
169 }
170 false
171 }
172
173 pub async fn morning_recall(
175 &self,
176 project: &ProjectIdentity,
177 namespace_id: i64,
178 memory_repo: &MemoryRepository,
179 embedding_service: Option<&dyn EmbeddingService>,
180 ) -> Vec<ColdRecall> {
181 let _start = std::time::Instant::now();
182 let query_string = format!(
183 "{} {} project context",
184 project.display_name,
185 project.git_remote.as_deref().unwrap_or("")
186 );
187 let hot_ids: std::collections::HashSet<i64> =
188 self.hot_cache.entries.iter().map(|e| e.memory_id).collect();
189
190 let mut results = Vec::new();
191
192 if let Some(service) = embedding_service {
193 match tokio::time::timeout(Duration::from_millis(2000), async {
194 if let Ok(embedding) = service.embed(&query_string).await {
195 let filters = nexus_storage::repository::ListMemoryFilters {
197 category: None,
198 since: None,
199 until: None,
200 content_like: None,
201 include_raw: false,
202 limit: 50,
203 offset: 0,
204 };
205
206 if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
207 let entries: Vec<VectorEntry> = memories
208 .into_iter()
209 .filter_map(|m| {
210 m.content_embedding.as_ref().map(|emb| {
211 VectorEntry::new(
212 m.id,
213 emb.clone(),
214 m.category.to_string(),
215 namespace_id,
216 )
217 })
218 })
219 .collect();
220
221 let search = SemanticSearch::new();
222 let options = SearchOptions::with_limit(20).with_threshold(0.65);
223
224 if let Ok((search_results, _)) =
225 search.search(&embedding, &entries, &options)
226 {
227 let filtered_results: Vec<_> = search_results
229 .into_iter()
230 .filter(|r| !hot_ids.contains(&r.id))
231 .take(10)
232 .collect();
233
234 let ids: Vec<i64> = filtered_results.iter().map(|r| r.id).collect();
235
236 let memories = match memory_repo.get_by_ids(&ids).await {
237 Ok(m) => m,
238 Err(e) => {
239 tracing::warn!("get_by_ids failed in morning_recall: {}", e);
240 Vec::new()
241 }
242 };
243
244 let memory_by_id: HashMap<i64, Memory> =
246 memories.into_iter().map(|m| (m.id, m)).collect();
247
248 let mut recalls = Vec::new();
249 for r in filtered_results {
250 if let Some(m) = memory_by_id.get(&r.id) {
251 if Self::is_system_memory(m) {
253 continue;
254 }
255 recalls.push(ColdRecall {
256 memory_id: r.id,
257 content: m.content.clone(),
258 relevance_score: r.score,
259 tier: ConfidenceTier::from_score(r.score),
260 });
261 }
262 }
263 return Ok::<Vec<ColdRecall>, AgentError>(recalls);
264 }
265 }
266 }
267 Ok(Vec::new())
268 })
269 .await
270 {
271 Ok(Ok(recalls)) => results = recalls,
272 Ok(Err(e)) => warn!("Morning recall vector search failed: {}", e),
273 Err(_) => warn!("Morning recall vector search timed out"),
274 }
275 }
276
277 if results.is_empty() {
278 let filters = nexus_storage::repository::ListMemoryFilters {
279 category: None,
280 since: None,
281 until: None,
282 content_like: Some(&project.display_name),
283 include_raw: false,
284 limit: 10,
285 offset: 0,
286 };
287
288 if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
289 results = memories
290 .into_iter()
291 .filter(|m| !hot_ids.contains(&m.id) && !Self::is_system_memory(m))
292 .take(10)
293 .map(|m| ColdRecall {
294 memory_id: m.id,
295 content: m.content,
296 relevance_score: 0.65,
297 tier: ConfidenceTier::Whisper,
298 })
299 .collect();
300 }
301
302 if results.is_empty() {
304 let mut sorted_cold: Vec<_> = self
306 .cold_index
307 .entries
308 .iter()
309 .filter(|e| !hot_ids.contains(&e.memory_id) && e.project_relevance >= 0.3)
310 .collect();
311 sorted_cold.sort_by(|a, b| {
312 b.project_relevance
313 .partial_cmp(&a.project_relevance)
314 .unwrap_or(std::cmp::Ordering::Equal)
315 });
316 let cold_ids: Vec<i64> = sorted_cold.iter().take(10).map(|e| e.memory_id).collect();
317
318 if !cold_ids.is_empty() {
319 match memory_repo.get_by_ids(&cold_ids).await {
320 Ok(cold_memories) => {
321 let cold_memory_by_id: HashMap<i64, Memory> =
322 cold_memories.into_iter().map(|m| (m.id, m)).collect();
323
324 for cold_entry in sorted_cold.iter().take(10) {
325 if let Some(m) = cold_memory_by_id.get(&cold_entry.memory_id) {
326 if Self::is_system_memory(m) {
328 continue;
329 }
330 results.push(ColdRecall {
331 memory_id: m.id,
332 content: m.content.clone(),
333 relevance_score: cold_entry.project_relevance,
334 tier: ConfidenceTier::from_score(
335 cold_entry.project_relevance,
336 ),
337 });
338 }
339 }
340 }
341 Err(e) => {
342 debug!("get_by_ids failed for cold_index in morning_recall: {}", e);
343 }
344 }
345 }
346 }
347 }
348
349 debug!(
350 "Morning recall found {} items in {:?}",
351 results.len(),
352 _start.elapsed()
353 );
354 results
355 }
356
357 pub fn load_or_init(nexus_dir: &Path) -> Self {
359 let cache_dir = nexus_dir.join("cache");
360 let hot_path = cache_dir.join("hot.json");
361 let cold_path = cache_dir.join("cold_index.json");
362
363 let hot_cache = if hot_path.exists() {
364 match std::fs::read_to_string(&hot_path) {
365 Ok(s) => match serde_json::from_str(&s) {
366 Ok(cache) => cache,
367 Err(e) => {
368 tracing::warn!(
369 path = %hot_path.display(),
370 error = %e,
371 "Failed to parse hot cache; using defaults"
372 );
373 HotCache::default()
374 }
375 },
376 Err(e) => {
377 tracing::warn!(
378 path = %hot_path.display(),
379 error = %e,
380 "Failed to read hot cache; using defaults"
381 );
382 HotCache::default()
383 }
384 }
385 } else {
386 HotCache::default()
387 };
388
389 let cold_index = if cold_path.exists() {
390 match std::fs::read_to_string(&cold_path) {
391 Ok(s) => match serde_json::from_str(&s) {
392 Ok(idx) => idx,
393 Err(e) => {
394 tracing::warn!(
395 path = %cold_path.display(),
396 error = %e,
397 "Failed to parse cold index; using defaults"
398 );
399 ColdCacheIndex::default()
400 }
401 },
402 Err(e) => {
403 tracing::warn!(
404 path = %cold_path.display(),
405 error = %e,
406 "Failed to read cold index; using defaults"
407 );
408 ColdCacheIndex::default()
409 }
410 }
411 } else {
412 ColdCacheIndex::default()
413 };
414
415 Self {
416 hot_cache,
417 cold_index,
418 }
419 }
420
421 pub fn save(&self, nexus_dir: &Path) -> std::io::Result<()> {
423 let cache_dir = nexus_dir.join("cache");
424 std::fs::create_dir_all(&cache_dir)?;
425
426 let hot_json = serde_json::to_string_pretty(&self.hot_cache)?;
427 nexus_core::fsutil::atomic_write(&cache_dir.join("hot.json"), &hot_json)?;
428
429 let cold_json = serde_json::to_string_pretty(&self.cold_index)?;
430 nexus_core::fsutil::atomic_write(&cache_dir.join("cold_index.json"), &cold_json)?;
431
432 Ok(())
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds an entry surfaced "now" with a streak of one and no source agent.
    fn sample_entry(
        memory_id: i64,
        content: &str,
        relevance_score: f32,
        tier: ConfidenceTier,
        pinned: bool,
    ) -> HotCacheEntry {
        HotCacheEntry {
            memory_id,
            content: content.into(),
            relevance_score,
            tier,
            promoted_at: Utc::now(),
            last_surfaced: Utc::now(),
            hot_streak: 1,
            pinned,
            source_agent: None,
        }
    }

    /// True when `hot` holds an entry for `memory_id`.
    fn holds(hot: &HotCache, memory_id: i64) -> bool {
        hot.entries.iter().any(|e| e.memory_id == memory_id)
    }

    #[test]
    fn test_confidence_tier_boundaries() {
        let expectations = [
            (0.85_f32, ConfidenceTier::Loud),
            (0.84, ConfidenceTier::Clear),
            (0.72, ConfidenceTier::Clear),
            (0.71, ConfidenceTier::Whisper),
            (0.50, ConfidenceTier::Whisper),
        ];
        for &(score, tier) in &expectations {
            assert_eq!(ConfidenceTier::from_score(score), tier, "score {}", score);
        }
    }

    #[test]
    fn test_hot_cache_promote_and_evict() {
        let capacity = 2;
        let mut hot = HotCache::default();

        hot.promote(sample_entry(1, "e1", 0.9, ConfidenceTier::Loud, false), capacity);
        hot.promote(sample_entry(2, "e2", 0.8, ConfidenceTier::Clear, false), capacity);
        assert_eq!(hot.entries.len(), 2);

        // A third distinct entry must displace one; entries 1 and 3 are expected to survive.
        hot.promote(sample_entry(3, "e3", 0.95, ConfidenceTier::Loud, false), capacity);
        assert_eq!(hot.entries.len(), 2);
        assert!(holds(&hot, 1));
        assert!(holds(&hot, 3));
    }

    #[test]
    fn test_hot_cache_never_evicts_pinned() {
        let capacity = 1;
        let mut hot = HotCache::default();

        hot.promote(sample_entry(1, "pinned", 0.1, ConfidenceTier::Whisper, true), capacity);
        hot.promote(sample_entry(2, "high", 0.99, ConfidenceTier::Loud, false), capacity);

        assert_eq!(hot.entries.len(), 1);
        assert_eq!(hot.entries[0].memory_id, 1);
    }

    #[test]
    fn test_cache_persistence_roundtrip() {
        let dir = tempdir().unwrap();
        let root = dir.path();

        let mut cache = CognitiveCache::default();
        cache
            .hot_cache
            .entries
            .push(sample_entry(1, "test", 0.9, ConfidenceTier::Loud, false));
        cache.save(root).unwrap();

        let loaded = CognitiveCache::load_or_init(root);
        assert_eq!(loaded.hot_cache.entries.len(), 1);
        assert_eq!(loaded.hot_cache.entries[0].content, "test");
    }

    #[test]
    fn test_load_or_init_handles_missing_and_corrupt() {
        let dir = tempdir().unwrap();
        let root = dir.path();

        // Nothing on disk yet.
        assert!(CognitiveCache::load_or_init(root).hot_cache.entries.is_empty());

        // Garbage in hot.json.
        let cache_dir = root.join("cache");
        std::fs::create_dir_all(&cache_dir).unwrap();
        std::fs::write(cache_dir.join("hot.json"), "invalid json").unwrap();
        assert!(CognitiveCache::load_or_init(root).hot_cache.entries.is_empty());
    }
}