Skip to main content

nexus_memory_hooks/
retrieval.rs

1//! Subconscious retrieval engine — surfaces relevant memories for injection.
2//!
//! Queries the cognitive cache and soul.md to build an XML context payload
//! for hook stdout injection (UserPromptSubmit / PreToolUse). Prompt-based
//! embedding search is wired in the CLI layer rather than in this module.
5
6use std::path::{Path, PathBuf};
7
8use serde::{Deserialize, Serialize};
9use tracing::debug;
10
11use nexus_agent::cognitive_cache::{CognitiveCache, ConfidenceTier};
12use nexus_agent::soul::soul_path;
13use nexus_core::Config;
14
15use crate::sync_state::{self, SyncState};
16
/// Operating mode for the subconscious pipeline.
///
/// Serialized through serde using the lowercase variant name (`whisper`,
/// `full`, `off`). `Whisper` is the `Default` variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum SubconsciousMode {
    /// Inject minimal guidance via stdout XML (default).
    #[default]
    Whisper,
    /// Inject full soul.md + memory blocks + active guidance.
    Full,
    /// Disable all subconscious hooks.
    Off,
}
29
30impl SubconsciousMode {
31    /// Read mode from `NEXUS_SUBCONSCIOUS_MODE` env var.
32    pub fn from_env() -> Self {
33        match std::env::var("NEXUS_SUBCONSCIOUS_MODE")
34            .unwrap_or_default()
35            .to_lowercase()
36            .as_str()
37        {
38            "full" => SubconsciousMode::Full,
39            "off" => SubconsciousMode::Off,
40            _ => SubconsciousMode::Whisper,
41        }
42    }
43}
44
/// Result of a memory retrieval operation.
///
/// Built by `RetrievalEngine::retrieve_for_prompt` and rendered to XML by
/// `RetrievalEngine::format_for_stdout`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// Soul.md content, or `None` when the file is missing, unreadable, or
    /// blank. As populated by this module the text is stored in full; the
    /// formatter truncates it to the soul budget when rendering.
    pub soul_content: Option<String>,
    /// Top relevant memories. As populated by this module these are hot-cache
    /// entries with a relevance score of at least 0.5, highest first, at most
    /// five.
    pub recalled: Vec<RecalledMemory>,
    /// `[TIER] content` lines for hot cache entries surfaced since the last
    /// sync (at most ten as populated by this module).
    pub active_guidance: Vec<String>,
    /// Counters used to build the context header line.
    pub stats: RetrievalStats,
}
57
/// A single memory selected for the `<nexus_recall>` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecalledMemory {
    /// Memory text; truncated and XML-escaped only at formatting time.
    pub content: String,
    /// Relevance score, rendered with two decimals in the `relevance` attribute.
    pub relevance: f32,
    /// Confidence tier, rendered as `LOUD`, `CLEAR`, or `WHISPER`.
    pub tier: ConfidenceTier,
    /// Origin label for the memory (this module uses `"hot_cache"`).
    pub source: String,
}
65
/// Counters describing the memory state at retrieval time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalStats {
    /// Number of indexed memories. `gather_stats` approximates this as
    /// hot-cache entries plus cold-index entries rather than querying a database.
    pub total_memories: usize,
    /// Number of entries currently in the hot cache.
    pub hot_cache_entries: usize,
    /// Whether the soul.md file exists on disk.
    pub soul_md_exists: bool,
    /// Minutes since soul.md was last modified; `None` when the file is
    /// missing or its modification time cannot be read.
    pub soul_md_age_minutes: Option<i64>,
}
73
/// The retrieval engine drives subconscious memory injection.
pub struct RetrievalEngine {
    /// Project directory; the cognitive cache is loaded from its `.nexus` subdirectory.
    project_root: PathBuf,
    /// Operating mode, read from the environment when the engine is created.
    mode: SubconsciousMode,
    /// Maximum tokens for the entire injection payload.
    token_budget: usize,
}
81
/// Token budget allocations for different injection sections.
///
/// Values are token counts; `format_for_stdout` multiplies them by 4 to
/// obtain the byte limit passed to `truncate_to_chars`.
struct TokenBudgets {
    /// Tokens reserved for the `<nexus_soul>` section.
    soul: usize,
    /// Tokens reserved for the `<nexus_recall>` section, shared by all recalled memories.
    recall: usize,
    /// Tokens reserved for the `<nexus_guidance>` section, shared by all guidance lines.
    guidance: usize,
}
88
89impl RetrievalEngine {
90    /// Create a new retrieval engine for a project.
91    pub fn new(project_root: &Path, _config: Config) -> Self {
92        let mode = SubconsciousMode::from_env();
93        let token_budget = match mode {
94            SubconsciousMode::Whisper => 8192,
95            SubconsciousMode::Full => 16384,
96            SubconsciousMode::Off => 0,
97        };
98        Self {
99            project_root: project_root.to_path_buf(),
100            mode,
101            token_budget,
102        }
103    }
104
105    /// Retrieve memories relevant to a prompt for UserPromptSubmit injection.
106    pub async fn retrieve_for_prompt(
107        &self,
108        prompt: &str,
109        sync_state: &SyncState,
110    ) -> RetrievalResult {
111        let stats = self.gather_stats();
112        let soul_content = self.load_soul_content();
113        let hot_cache = self.load_hot_cache();
114        let recalled = self.search_memories(prompt, &hot_cache).await;
115        let active_guidance = self.compute_guidance(&hot_cache, sync_state);
116
117        RetrievalResult {
118            soul_content,
119            recalled,
120            active_guidance,
121            stats,
122        }
123    }
124
125    /// Lightweight check for updates since last sync (PreToolUse hook).
126    pub fn check_for_updates(&self, sync_state: &mut SyncState) -> Option<String> {
127        let soul_content = self.load_soul_content();
128        let soul_hash = sync_state::soul_content_hash(soul_content.as_deref().unwrap_or(""));
129        let hot_cache = self.load_hot_cache();
130        let hot_cache_ids: Vec<String> = hot_cache
131            .hot_cache
132            .entries
133            .iter()
134            .map(|e| e.memory_id.to_string())
135            .collect();
136        let hot_cache_hash = sync_state::hot_cache_hash(&hot_cache_ids);
137
138        if !sync_state.has_updates(
139            &soul_hash,
140            hot_cache.hot_cache.entries.len(),
141            &hot_cache_hash,
142        ) {
143            return None;
144        }
145
146        // Build a lightweight delta injection
147        let mut parts = Vec::new();
148
149        if soul_hash != sync_state.last_soul_hash {
150            if let Some(soul) = soul_content {
151                let truncated = truncate_to_chars(&soul, 8192);
152                parts.push(format!(
153                    "<soul_update>\n{}\n</soul_update>",
154                    escape_xml(&truncated)
155                ));
156            } else {
157                parts.push("<soul_update deleted=\"true\" />".to_string());
158            }
159        }
160
161        if hot_cache.hot_cache.entries.len() > sync_state.last_hot_cache_count {
162            // Cache grew: list the newly promoted entries
163            let new_count = hot_cache
164                .hot_cache
165                .entries
166                .len()
167                .saturating_sub(sync_state.last_hot_cache_count);
168            let new_entries: Vec<_> = hot_cache
169                .hot_cache
170                .entries
171                .iter()
172                .rev()
173                .take(new_count)
174                .map(|e| {
175                    let tier = match e.tier {
176                        ConfidenceTier::Loud => "LOUD",
177                        ConfidenceTier::Clear => "CLEAR",
178                        ConfidenceTier::Whisper => "WHISPER",
179                    };
180                    format!(
181                        "[{}] {}",
182                        tier,
183                        escape_xml(&truncate_to_chars(&e.content, 8192))
184                    )
185                })
186                .collect();
187            parts.push(format!(
188                "<cache_promotions count=\"{new_count}\">\n{}\n</cache_promotions>",
189                new_entries.join("\n")
190            ));
191        } else if hot_cache_hash != sync_state.last_hot_cache_hash {
192            // Cache content changed without net growth (eviction at capacity,
193            // or count decreased). Emit a summary of the current state.
194            let entries: Vec<_> = hot_cache
195                .hot_cache
196                .entries
197                .iter()
198                .rev()
199                .take(5)
200                .map(|e| {
201                    let tier = match e.tier {
202                        ConfidenceTier::Loud => "LOUD",
203                        ConfidenceTier::Clear => "CLEAR",
204                        ConfidenceTier::Whisper => "WHISPER",
205                    };
206                    format!(
207                        "[{}] {}",
208                        tier,
209                        escape_xml(&truncate_to_chars(&e.content, 8192))
210                    )
211                })
212                .collect();
213            parts.push(format!(
214                "<cache_update count=\"{}\">\n{}\n</cache_update>",
215                hot_cache.hot_cache.entries.len(),
216                entries.join("\n")
217            ));
218        }
219
220        if parts.is_empty() {
221            return None;
222        }
223
224        // Advance the sync state watermark to prevent re-emission
225        sync_state.advance(
226            soul_hash,
227            hot_cache.hot_cache.entries.len(),
228            hot_cache_hash,
229            None,
230        );
231
232        Some(format!(
233            "<nexus_delta>\n{}\n</nexus_delta>",
234            parts.join("\n")
235        ))
236    }
237
238    /// Format a retrieval result as XML for stdout injection.
239    pub fn format_for_stdout(&self, result: &RetrievalResult) -> String {
240        if self.mode == SubconsciousMode::Off {
241            return String::new();
242        }
243
244        let budgets = self.compute_budgets();
245        let mut sections = Vec::new();
246
247        // Context header
248        let memory_count = result.stats.total_memories;
249        let hot_count = result.stats.hot_cache_entries;
250        let soul_status = if result.stats.soul_md_exists {
251            "synthesized"
252        } else {
253            "not yet generated"
254        };
255
256        sections.push(format!(
257            "<nexus_context>\n\
258             Subconscious memory active. {memory_count} memories indexed, \
259             {hot_count} in hot cache, soul.md {soul_status}.\n\
260             </nexus_context>"
261        ));
262
263        // Soul section (now included in both Whisper and Full modes)
264        if let Some(ref soul) = result.soul_content {
265            let truncated = truncate_to_chars(soul, budgets.soul * 4);
266            sections.push(format!(
267                "<nexus_soul>\n{}\n</nexus_soul>",
268                escape_xml(&truncated)
269            ));
270        }
271
272        // Recall section
273        if !result.recalled.is_empty() {
274            let mut entries = Vec::new();
275            for mem in &result.recalled {
276                let tier = match mem.tier {
277                    ConfidenceTier::Loud => "LOUD",
278                    ConfidenceTier::Clear => "CLEAR",
279                    ConfidenceTier::Whisper => "WHISPER",
280                };
281                let truncated = truncate_to_chars(
282                    &mem.content,
283                    budgets.recall * 4 / result.recalled.len().max(1),
284                );
285                entries.push(format!(
286                    "<memory relevance=\"{:.2}\" tier=\"{tier}\" source=\"{}\">\n{}\n</memory>",
287                    mem.relevance,
288                    escape_xml(&mem.source),
289                    escape_xml(&truncated)
290                ));
291            }
292            sections.push(format!(
293                "<nexus_recall>\n{}\n</nexus_recall>",
294                entries.join("\n")
295            ));
296        }
297
298        // Active guidance section
299        if !result.active_guidance.is_empty() {
300            let truncated_guidance: Vec<_> = result
301                .active_guidance
302                .iter()
303                .map(|g| {
304                    escape_xml(&truncate_to_chars(
305                        g,
306                        budgets.guidance * 4 / result.active_guidance.len().max(1),
307                    ))
308                })
309                .collect();
310            sections.push(format!(
311                "<nexus_guidance>\n{}\n</nexus_guidance>",
312                truncated_guidance.join("\n")
313            ));
314        }
315
316        sections.join("\n\n")
317    }
318
319    /// Format the initial session-start injection.
320    pub fn format_session_start(
321        &self,
322        hot_cache: &CognitiveCache,
323        soul_content: Option<&str>,
324    ) -> String {
325        if self.mode == SubconsciousMode::Off {
326            return String::new();
327        }
328
329        let mut parts = Vec::new();
330
331        parts.push(format!(
332            "<nexus_context>\n\
333             Subconscious memory active. {} entries in hot cache.\n\
334             Soul.md {}.\n\
335             </nexus_context>",
336            hot_cache.hot_cache.entries.len(),
337            if soul_content.is_some() {
338                "loaded"
339            } else {
340                "not yet generated"
341            }
342        ));
343
344        // Show soul content in both Whisper and Full modes (key difference from before)
345        if let Some(soul) = soul_content {
346            let truncated = truncate_to_chars(soul, 8192);
347            parts.push(format!(
348                "<nexus_soul>\n{}\n</nexus_soul>",
349                escape_xml(&truncated)
350            ));
351        }
352
353        if self.mode == SubconsciousMode::Full {
354            // Show all hot cache entries in full mode
355            if !hot_cache.hot_cache.entries.is_empty() {
356                let entries: Vec<_> = hot_cache
357                    .hot_cache
358                    .entries
359                    .iter()
360                    .map(|e| {
361                        let tier = match e.tier {
362                            ConfidenceTier::Loud => "LOUD",
363                            ConfidenceTier::Clear => "CLEAR",
364                            ConfidenceTier::Whisper => "WHISPER",
365                        };
366                        format!(
367                            "[{tier}] {}",
368                            escape_xml(&truncate_to_chars(&e.content, 8192))
369                        )
370                    })
371                    .collect();
372                parts.push(format!(
373                    "<nexus_hot_cache>\n{}\n</nexus_hot_cache>",
374                    entries.join("\n")
375                ));
376            }
377        } else {
378            // Whisper mode: show top 10 by relevance with expanded content
379            let mut sorted = hot_cache.hot_cache.entries.clone();
380            sorted.sort_by(|a, b| {
381                b.relevance_score
382                    .partial_cmp(&a.relevance_score)
383                    .unwrap_or(std::cmp::Ordering::Equal)
384            });
385            let top: Vec<_> = sorted
386                .iter()
387                .take(10)
388                .map(|e| {
389                    let tier = match e.tier {
390                        ConfidenceTier::Loud => "LOUD",
391                        ConfidenceTier::Clear => "CLEAR",
392                        ConfidenceTier::Whisper => "WHISPER",
393                    };
394                    format!(
395                        "[{tier}] {}",
396                        escape_xml(&truncate_to_chars(&e.content, 8192))
397                    )
398                })
399                .collect();
400            if !top.is_empty() {
401                parts.push(format!(
402                    "<nexus_whisper>\n{}\n</nexus_whisper>",
403                    top.join("\n")
404                ));
405            }
406        }
407
408        parts.join("\n\n")
409    }
410
411    // ── Internal helpers ──────────────────────────────────────────────
412
413    fn gather_stats(&self) -> RetrievalStats {
414        let hot_cache = self.load_hot_cache();
415        let soul_path = soul_path();
416
417        let (soul_md_exists, soul_md_age_minutes) = if soul_path.exists() {
418            let age = std::fs::metadata(&soul_path)
419                .ok()
420                .and_then(|m| m.modified().ok())
421                .map(|modified| {
422                    let modified: chrono::DateTime<chrono::Local> = modified.into();
423                    chrono::Utc::now()
424                        .signed_duration_since(modified.with_timezone(&chrono::Utc))
425                        .num_minutes()
426                });
427            (true, age)
428        } else {
429            (false, None)
430        };
431
432        // For total_memories, we'd need a DB query; approximate from hot+cold cache
433        let total = hot_cache.hot_cache.entries.len() + hot_cache.cold_index.entries.len();
434
435        RetrievalStats {
436            total_memories: total,
437            hot_cache_entries: hot_cache.hot_cache.entries.len(),
438            soul_md_exists,
439            soul_md_age_minutes,
440        }
441    }
442
443    pub fn load_soul_content(&self) -> Option<String> {
444        let path = soul_path();
445        if !path.exists() {
446            return None;
447        }
448        match std::fs::read_to_string(&path) {
449            Ok(content) => {
450                if content.trim().is_empty() {
451                    None
452                } else {
453                    Some(content)
454                }
455            }
456            Err(e) => {
457                debug!("Failed to read soul.md: {e}");
458                None
459            }
460        }
461    }
462
463    fn load_hot_cache(&self) -> CognitiveCache {
464        let nexus_dir = self.project_root.join(".nexus");
465        CognitiveCache::load_or_init(&nexus_dir)
466    }
467
468    /// Hot-cache-only search. Prompt-based semantic retrieval is wired in the CLI
469    /// layer (`subconscious.rs`) which has async DB access.
470    async fn search_memories(
471        &self,
472        _prompt: &str,
473        hot_cache: &CognitiveCache,
474    ) -> Vec<RecalledMemory> {
475        let mut entries: Vec<_> = hot_cache
476            .hot_cache
477            .entries
478            .iter()
479            .filter(|e| e.relevance_score >= 0.5)
480            .map(|e| RecalledMemory {
481                content: e.content.clone(),
482                relevance: e.relevance_score,
483                tier: e.tier,
484                source: "hot_cache".to_string(),
485            })
486            .collect();
487
488        entries.sort_by(|a, b| {
489            b.relevance
490                .partial_cmp(&a.relevance)
491                .unwrap_or(std::cmp::Ordering::Equal)
492        });
493        entries.truncate(5);
494        entries
495    }
496
497    fn compute_guidance(&self, hot_cache: &CognitiveCache, sync_state: &SyncState) -> Vec<String> {
498        // Find entries surfaced since last sync
499        let mut guidance = Vec::new();
500        for entry in &hot_cache.hot_cache.entries {
501            if entry.last_surfaced > sync_state.last_sync_timestamp {
502                let tier = match entry.tier {
503                    ConfidenceTier::Loud => "LOUD",
504                    ConfidenceTier::Clear => "CLEAR",
505                    ConfidenceTier::Whisper => "WHISPER",
506                };
507                guidance.push(format!(
508                    "[{tier}] {}",
509                    truncate_to_chars(&entry.content, 8192)
510                ));
511            }
512        }
513        guidance.truncate(10);
514        guidance
515    }
516
517    fn compute_budgets(&self) -> TokenBudgets {
518        let total = self.token_budget;
519        match self.mode {
520            SubconsciousMode::Whisper => TokenBudgets {
521                soul: total / 4,
522                recall: total * 3 / 8,
523                guidance: total * 3 / 8,
524            },
525            SubconsciousMode::Full => TokenBudgets {
526                soul: total / 2,
527                recall: total / 4,
528                guidance: total / 4,
529            },
530            SubconsciousMode::Off => TokenBudgets {
531                soul: 0,
532                recall: 0,
533                guidance: 0,
534            },
535        }
536    }
537}
538
/// Truncate text to at most `max_chars` bytes, preserving UTF-8 validity and,
/// where reasonable, word boundaries.
///
/// Despite the name, `max_chars` is a byte limit (it is compared against
/// `str::len`). Text that already fits is returned unchanged. Otherwise the
/// result is a prefix of `text` followed by `…`, and the whole result
/// (ellipsis included) fits within `max_chars` bytes. When `max_chars` is too
/// small to hold anything besides the ellipsis, an empty string is returned.
///
/// The cut is moved back to the last space only when that space lies in the
/// second half of the kept prefix. Text with a single early space followed by
/// a long unbroken run (CJK text, URLs, minified data) therefore keeps most of
/// its budget instead of collapsing to a few bytes.
fn truncate_to_chars(text: &str, max_chars: usize) -> String {
    const ELLIPSIS: &str = "…";

    if text.len() <= max_chars {
        return text.to_string();
    }
    if max_chars <= ELLIPSIS.len() {
        return String::new();
    }

    // Largest char boundary that still leaves room for the ellipsis.
    let mut end = max_chars - ELLIPSIS.len();
    while end > 0 && !text.is_char_boundary(end) {
        end -= 1;
    }

    // Prefer a word boundary, but never give up more than half the prefix for it.
    if let Some(pos) = text[..end].rfind(' ') {
        if pos >= end / 2 {
            end = pos;
        }
    }

    let mut out = String::with_capacity(end + ELLIPSIS.len());
    out.push_str(&text[..end]);
    out.push_str(ELLIPSIS);
    out
}
563
/// Escape the five XML special characters (`&`, `<`, `>`, `"`, `'`) with
/// their predefined entities; every other character is copied through as-is.
fn escape_xml(s: &str) -> String {
    s.chars()
        .fold(String::with_capacity(s.len()), |mut escaped, ch| {
            match ch {
                '&' => escaped.push_str("&amp;"),
                '<' => escaped.push_str("&lt;"),
                '>' => escaped.push_str("&gt;"),
                '"' => escaped.push_str("&quot;"),
                '\'' => escaped.push_str("&apos;"),
                other => escaped.push(other),
            }
            escaped
        })
}
579
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn test_config() -> Config {
        Config::default()
    }

    /// Engine rooted in a fresh temp dir and forced into `mode`, regardless of
    /// what the environment says. The `TempDir` is returned so the directory
    /// outlives the engine for the duration of the test.
    fn engine_in(mode: SubconsciousMode) -> (TempDir, RetrievalEngine) {
        let dir = TempDir::new().unwrap();
        let mut engine = RetrievalEngine::new(dir.path(), test_config());
        engine.mode = mode;
        (dir, engine)
    }

    /// Stats fixture: ten memories; the soul age is five minutes when the soul exists.
    fn stats(hot_cache_entries: usize, soul_md_exists: bool) -> RetrievalStats {
        RetrievalStats {
            total_memories: 10,
            hot_cache_entries,
            soul_md_exists,
            soul_md_age_minutes: soul_md_exists.then_some(5),
        }
    }

    #[test]
    fn mode_from_env_default_is_whisper() {
        // Ensure no env var set
        std::env::remove_var("NEXUS_SUBCONSCIOUS_MODE");
        assert_eq!(SubconsciousMode::from_env(), SubconsciousMode::Whisper);
    }

    #[test]
    fn truncate_preserves_short_text() {
        assert_eq!(truncate_to_chars("hello", 10), "hello");
    }

    #[test]
    fn truncate_truncates_long_text() {
        let result = truncate_to_chars("hello world this is a test of truncation", 15);
        assert!(result.ends_with('…'));
        assert!(result.len() < 20);
    }

    #[test]
    fn truncate_handles_multibyte() {
        let text = "héllo wörld";
        let result = truncate_to_chars(text, 5);
        assert!(result.len() < 15);
        // Should not panic and should produce valid UTF-8
        assert!(result.ends_with('…') || result == "héllo");
    }

    #[test]
    fn escape_xml_handles_special_chars() {
        assert_eq!(
            escape_xml("a<b>c&d\"e'f"),
            "a&lt;b&gt;c&amp;d&quot;e&apos;f"
        );
    }

    #[test]
    fn format_off_mode_returns_empty() {
        let (_dir, engine) = engine_in(SubconsciousMode::Off);
        let result = RetrievalResult {
            soul_content: Some("test".to_string()),
            recalled: vec![],
            active_guidance: vec![],
            stats: stats(3, true),
        };
        assert!(engine.format_for_stdout(&result).is_empty());
    }

    #[test]
    fn format_whisper_mode_includes_context() {
        let (_dir, engine) = engine_in(SubconsciousMode::Whisper);
        let result = RetrievalResult {
            soul_content: None,
            recalled: vec![RecalledMemory {
                content: "test memory".to_string(),
                relevance: 0.9,
                tier: ConfidenceTier::Loud,
                source: "hot_cache".to_string(),
            }],
            active_guidance: vec![],
            stats: stats(3, false),
        };
        let output = engine.format_for_stdout(&result);
        assert!(output.contains("<nexus_context>"));
        assert!(output.contains("<nexus_recall>"));
        assert!(!output.contains("<nexus_soul>"));
    }

    #[test]
    fn format_full_mode_includes_soul() {
        let (_dir, engine) = engine_in(SubconsciousMode::Full);
        let result = RetrievalResult {
            soul_content: Some("I am a helpful assistant".to_string()),
            recalled: vec![],
            active_guidance: vec![],
            stats: stats(0, true),
        };
        let output = engine.format_for_stdout(&result);
        assert!(output.contains("<nexus_soul>"));
        assert!(output.contains("I am a helpful assistant"));
    }

    #[test]
    fn check_for_updates_returns_none_when_unchanged() {
        let dir = TempDir::new().unwrap();
        let engine = RetrievalEngine::new(dir.path(), test_config());

        // Seed the sync state with exactly what the engine currently observes.
        let mut state = SyncState::new("test");
        let soul_content = engine.load_soul_content();
        state.last_soul_hash = sync_state::soul_content_hash(soul_content.as_deref().unwrap_or(""));
        state.last_hot_cache_count = engine.load_hot_cache().hot_cache.entries.len();
        let hot_cache_ids: Vec<String> = engine
            .load_hot_cache()
            .hot_cache
            .entries
            .iter()
            .map(|e| e.memory_id.to_string())
            .collect();
        state.last_hot_cache_hash = sync_state::hot_cache_hash(&hot_cache_ids);

        assert!(engine.check_for_updates(&mut state).is_none());
    }

    #[test]
    fn session_start_format_contains_context() {
        let (_dir, engine) = engine_in(SubconsciousMode::Whisper);
        let cache = CognitiveCache::default();
        let output = engine.format_session_start(&cache, None);
        assert!(output.contains("<nexus_context>"));
    }
}