1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7use std::time::Duration;
8use tracing::{debug, warn};
9
10use crate::context_builder::ColdRecall;
11use crate::error::AgentError;
12use nexus_core::{EmbeddingService, Memory, ProjectIdentity};
13use nexus_storage::repository::MemoryRepository;
14use nexus_vectors::{SearchOptions, SemanticSearch, VectorEntry};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
18pub enum ConfidenceTier {
19 Whisper,
21 Clear,
23 Loud,
25}
26
27impl ConfidenceTier {
28 pub fn from_score(score: f32) -> Self {
30 if score >= 0.85 {
31 ConfidenceTier::Loud
32 } else if score >= 0.72 {
33 ConfidenceTier::Clear
34 } else {
35 ConfidenceTier::Whisper
36 }
37 }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct HotCacheEntry {
43 pub memory_id: i64,
44 pub content: String,
45 pub relevance_score: f32,
46 pub tier: ConfidenceTier,
47 pub promoted_at: DateTime<Utc>,
48 pub last_surfaced: DateTime<Utc>,
49 pub hot_streak: u32,
50 pub pinned: bool,
51 pub source_agent: Option<String>,
52}
53
54impl HotCacheEntry {
55 pub fn eviction_score(&self) -> f32 {
58 if self.pinned {
59 return f32::MAX;
60 }
61
62 let now = Utc::now();
63 let age_secs = now
64 .signed_duration_since(self.last_surfaced)
65 .num_seconds()
66 .max(1) as f32;
67
68 let age_days = (age_secs / 86400.0).min(80.0);
70 let recency_penalty = age_days.exp();
71
72 let frequency_boost = (self.hot_streak as f32).ln().max(1.0);
74
75 (self.relevance_score * frequency_boost) / recency_penalty
76 }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize, Default)]
81pub struct HotCache {
82 pub entries: Vec<HotCacheEntry>,
83 pub last_updated: Option<DateTime<Utc>>,
84 pub last_session_id: Option<String>,
85}
86
87impl HotCache {
88 pub fn promote(&mut self, entry: HotCacheEntry, max_entries: usize) -> bool {
90 if let Some(existing) = self
91 .entries
92 .iter_mut()
93 .find(|e| e.memory_id == entry.memory_id)
94 {
95 existing.content = entry.content;
96 existing.relevance_score = entry.relevance_score;
97 existing.tier = entry.tier;
98 existing.hot_streak += 1;
99 existing.last_surfaced = Utc::now();
100 existing.pinned = existing.pinned || entry.pinned; return true;
102 }
103
104 if self.entries.len() >= max_entries {
105 let mut candidates: Vec<(usize, f32)> = self
107 .entries
108 .iter()
109 .enumerate()
110 .filter(|(_, e)| !e.pinned)
111 .map(|(i, e)| (i, e.eviction_score()))
112 .collect();
113
114 if !candidates.is_empty() {
115 candidates
117 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
118 self.entries.remove(candidates[0].0);
119 } else {
120 return false;
122 }
123 }
124
125 self.entries.push(entry);
126 self.last_updated = Some(Utc::now());
127 true
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
133pub struct ColdIndexEntry {
134 pub memory_id: i64,
135 pub project_relevance: f32,
136 pub last_surfaced: Option<DateTime<Utc>>,
137}
138
139#[derive(Debug, Clone, Serialize, Deserialize, Default)]
141pub struct ColdCacheIndex {
142 pub entries: Vec<ColdIndexEntry>,
143 pub last_reindexed: Option<DateTime<Utc>>,
144}
145
146#[derive(Debug, Clone, Serialize, Deserialize, Default)]
148pub struct CognitiveCache {
149 pub hot_cache: HotCache,
150 pub cold_index: ColdCacheIndex,
151}
152
153impl CognitiveCache {
154 pub async fn morning_recall(
156 &self,
157 project: &ProjectIdentity,
158 namespace_id: i64,
159 memory_repo: &MemoryRepository,
160 embedding_service: Option<&dyn EmbeddingService>,
161 ) -> Vec<ColdRecall> {
162 let _start = std::time::Instant::now();
163 let query_string = format!(
164 "{} {} project context",
165 project.display_name,
166 project.git_remote.as_deref().unwrap_or("")
167 );
168 let hot_ids: std::collections::HashSet<i64> =
169 self.hot_cache.entries.iter().map(|e| e.memory_id).collect();
170
171 let mut results = Vec::new();
172
173 if let Some(service) = embedding_service {
174 match tokio::time::timeout(Duration::from_millis(2000), async {
175 if let Ok(embedding) = service.embed(&query_string).await {
176 let filters = nexus_storage::repository::ListMemoryFilters {
178 category: None,
179 since: None,
180 until: None,
181 content_like: None,
182 include_raw: false,
183 limit: 50,
184 offset: 0,
185 };
186
187 if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
188 let entries: Vec<VectorEntry> = memories
189 .into_iter()
190 .filter_map(|m| {
191 m.content_embedding.as_ref().map(|emb| {
192 VectorEntry::new(
193 m.id,
194 emb.clone(),
195 m.category.to_string(),
196 namespace_id,
197 )
198 })
199 })
200 .collect();
201
202 let search = SemanticSearch::new();
203 let options = SearchOptions::with_limit(20).with_threshold(0.65);
204
205 if let Ok((search_results, _)) =
206 search.search(&embedding, &entries, &options)
207 {
208 let filtered_results: Vec<_> = search_results
210 .into_iter()
211 .filter(|r| !hot_ids.contains(&r.id))
212 .take(10)
213 .collect();
214
215 let ids: Vec<i64> = filtered_results.iter().map(|r| r.id).collect();
216
217 let memories = match memory_repo.get_by_ids(&ids).await {
218 Ok(m) => m,
219 Err(e) => {
220 tracing::warn!("get_by_ids failed in morning_recall: {}", e);
221 Vec::new()
222 }
223 };
224
225 let memory_by_id: HashMap<i64, Memory> =
227 memories.into_iter().map(|m| (m.id, m)).collect();
228
229 let mut recalls = Vec::new();
230 for r in filtered_results {
231 if let Some(m) = memory_by_id.get(&r.id) {
232 recalls.push(ColdRecall {
233 memory_id: r.id,
234 content: m.content.clone(),
235 relevance_score: r.score,
236 tier: ConfidenceTier::from_score(r.score),
237 });
238 }
239 }
240 return Ok::<Vec<ColdRecall>, AgentError>(recalls);
241 }
242 }
243 }
244 Ok(Vec::new())
245 })
246 .await
247 {
248 Ok(Ok(recalls)) => results = recalls,
249 Ok(Err(e)) => warn!("Morning recall vector search failed: {}", e),
250 Err(_) => warn!("Morning recall vector search timed out"),
251 }
252 }
253
254 if results.is_empty() {
255 let filters = nexus_storage::repository::ListMemoryFilters {
256 category: None,
257 since: None,
258 until: None,
259 content_like: Some(&project.display_name),
260 include_raw: false,
261 limit: 10,
262 offset: 0,
263 };
264
265 if let Ok(memories) = memory_repo.list_filtered(namespace_id, filters).await {
266 results = memories
267 .into_iter()
268 .filter(|m| !hot_ids.contains(&m.id))
269 .take(10)
270 .map(|m| ColdRecall {
271 memory_id: m.id,
272 content: m.content,
273 relevance_score: 0.65,
274 tier: ConfidenceTier::Whisper,
275 })
276 .collect();
277 }
278
279 if results.is_empty() {
281 let mut sorted_cold: Vec<_> = self
283 .cold_index
284 .entries
285 .iter()
286 .filter(|e| !hot_ids.contains(&e.memory_id) && e.project_relevance >= 0.3)
287 .collect();
288 sorted_cold.sort_by(|a, b| {
289 b.project_relevance
290 .partial_cmp(&a.project_relevance)
291 .unwrap_or(std::cmp::Ordering::Equal)
292 });
293 let cold_ids: Vec<i64> = sorted_cold.iter().take(10).map(|e| e.memory_id).collect();
294
295 if !cold_ids.is_empty() {
296 match memory_repo.get_by_ids(&cold_ids).await {
297 Ok(cold_memories) => {
298 let cold_memory_by_id: HashMap<i64, Memory> =
299 cold_memories.into_iter().map(|m| (m.id, m)).collect();
300
301 for cold_entry in sorted_cold.iter().take(10) {
302 if let Some(m) = cold_memory_by_id.get(&cold_entry.memory_id) {
303 results.push(ColdRecall {
304 memory_id: m.id,
305 content: m.content.clone(),
306 relevance_score: cold_entry.project_relevance,
307 tier: ConfidenceTier::from_score(
308 cold_entry.project_relevance,
309 ),
310 });
311 }
312 }
313 }
314 Err(e) => {
315 debug!("get_by_ids failed for cold_index in morning_recall: {}", e);
316 }
317 }
318 }
319 }
320 }
321
322 debug!(
323 "Morning recall found {} items in {:?}",
324 results.len(),
325 _start.elapsed()
326 );
327 results
328 }
329
330 pub fn load_or_init(nexus_dir: &Path) -> Self {
332 let cache_dir = nexus_dir.join("cache");
333 let hot_path = cache_dir.join("hot.json");
334 let cold_path = cache_dir.join("cold_index.json");
335
336 let hot_cache = if hot_path.exists() {
337 match std::fs::read_to_string(&hot_path) {
338 Ok(s) => match serde_json::from_str(&s) {
339 Ok(cache) => cache,
340 Err(e) => {
341 tracing::warn!(
342 path = %hot_path.display(),
343 error = %e,
344 "Failed to parse hot cache; using defaults"
345 );
346 HotCache::default()
347 }
348 },
349 Err(e) => {
350 tracing::warn!(
351 path = %hot_path.display(),
352 error = %e,
353 "Failed to read hot cache; using defaults"
354 );
355 HotCache::default()
356 }
357 }
358 } else {
359 HotCache::default()
360 };
361
362 let cold_index = if cold_path.exists() {
363 match std::fs::read_to_string(&cold_path) {
364 Ok(s) => match serde_json::from_str(&s) {
365 Ok(idx) => idx,
366 Err(e) => {
367 tracing::warn!(
368 path = %cold_path.display(),
369 error = %e,
370 "Failed to parse cold index; using defaults"
371 );
372 ColdCacheIndex::default()
373 }
374 },
375 Err(e) => {
376 tracing::warn!(
377 path = %cold_path.display(),
378 error = %e,
379 "Failed to read cold index; using defaults"
380 );
381 ColdCacheIndex::default()
382 }
383 }
384 } else {
385 ColdCacheIndex::default()
386 };
387
388 Self {
389 hot_cache,
390 cold_index,
391 }
392 }
393
394 pub fn save(&self, nexus_dir: &Path) -> std::io::Result<()> {
396 let cache_dir = nexus_dir.join("cache");
397 std::fs::create_dir_all(&cache_dir)?;
398
399 let hot_json = serde_json::to_string_pretty(&self.hot_cache)?;
400 nexus_core::fsutil::atomic_write(&cache_dir.join("hot.json"), &hot_json)?;
401
402 let cold_json = serde_json::to_string_pretty(&self.cold_index)?;
403 nexus_core::fsutil::atomic_write(&cache_dir.join("cold_index.json"), &cold_json)?;
404
405 Ok(())
406 }
407}
408
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds an entry surfaced "now" with a streak of 1 and no source agent.
    fn entry(
        memory_id: i64,
        content: &str,
        relevance_score: f32,
        tier: ConfidenceTier,
        pinned: bool,
    ) -> HotCacheEntry {
        let now = Utc::now();
        HotCacheEntry {
            memory_id,
            content: content.into(),
            relevance_score,
            tier,
            promoted_at: now,
            last_surfaced: now,
            hot_streak: 1,
            pinned,
            source_agent: None,
        }
    }

    #[test]
    fn test_confidence_tier_boundaries() {
        let cases = [
            (0.85, ConfidenceTier::Loud),
            (0.84, ConfidenceTier::Clear),
            (0.72, ConfidenceTier::Clear),
            (0.71, ConfidenceTier::Whisper),
            (0.50, ConfidenceTier::Whisper),
        ];
        for (score, expected) in cases {
            assert_eq!(ConfidenceTier::from_score(score), expected);
        }
    }

    #[test]
    fn test_hot_cache_promote_and_evict() {
        let capacity = 2;
        let mut hot = HotCache::default();

        hot.promote(entry(1, "e1", 0.9, ConfidenceTier::Loud, false), capacity);
        hot.promote(entry(2, "e2", 0.8, ConfidenceTier::Clear, false), capacity);
        assert_eq!(hot.entries.len(), 2);

        // The cache is full, so the lowest-scoring entry (id 2) has to make room.
        hot.promote(entry(3, "e3", 0.95, ConfidenceTier::Loud, false), capacity);
        assert_eq!(hot.entries.len(), 2);

        let cached_ids: Vec<i64> = hot.entries.iter().map(|e| e.memory_id).collect();
        assert!(cached_ids.contains(&1));
        assert!(cached_ids.contains(&3));
    }

    #[test]
    fn test_hot_cache_never_evicts_pinned() {
        let capacity = 1;
        let mut hot = HotCache::default();

        hot.promote(entry(1, "pinned", 0.1, ConfidenceTier::Whisper, true), capacity);
        // Higher relevance, but the only resident entry is pinned.
        hot.promote(entry(2, "high", 0.99, ConfidenceTier::Loud, false), capacity);

        assert_eq!(hot.entries.len(), 1);
        assert_eq!(hot.entries[0].memory_id, 1);
    }

    #[test]
    fn test_cache_persistence_roundtrip() {
        let dir = tempdir().unwrap();
        let nexus_dir = dir.path();

        let mut cache = CognitiveCache::default();
        cache
            .hot_cache
            .entries
            .push(entry(1, "test", 0.9, ConfidenceTier::Loud, false));

        cache.save(nexus_dir).unwrap();
        let loaded = CognitiveCache::load_or_init(nexus_dir);

        assert_eq!(loaded.hot_cache.entries.len(), 1);
        assert_eq!(loaded.hot_cache.entries[0].content, "test");
    }

    #[test]
    fn test_load_or_init_handles_missing_and_corrupt() {
        let dir = tempdir().unwrap();
        let nexus_dir = dir.path();

        // No cache directory at all.
        let missing = CognitiveCache::load_or_init(nexus_dir);
        assert_eq!(missing.hot_cache.entries.len(), 0);

        // A hot cache file that is not JSON.
        let cache_dir = nexus_dir.join("cache");
        std::fs::create_dir_all(&cache_dir).unwrap();
        std::fs::write(cache_dir.join("hot.json"), "invalid json").unwrap();

        let corrupt = CognitiveCache::load_or_init(nexus_dir);
        assert_eq!(corrupt.hot_cache.entries.len(), 0);
    }
}