Skip to main content

nexus_memory_hooks/
retrieval.rs

1//! Subconscious retrieval engine — surfaces relevant memories for injection.
2//!
//! Reads the cognitive (hot) cache and soul.md to build an XML context
//! payload for hook stdout injection (UserPromptSubmit / PreToolUse /
//! session start). Prompt-based embedding search is not performed here; it
//! is wired in the CLI layer.
5
6use std::path::{Path, PathBuf};
7
8use serde::{Deserialize, Serialize};
9use tracing::debug;
10
11use nexus_agent::cognitive_cache::{CognitiveCache, ConfidenceTier};
12use nexus_agent::soul::soul_path;
13use nexus_core::Config;
14
15use crate::sync_state::{self, SyncState};
16
17/// Operating mode for the subconscious pipeline.
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
19#[serde(rename_all = "lowercase")]
20pub enum SubconsciousMode {
21    /// Inject minimal guidance via stdout XML (default).
22    #[default]
23    Whisper,
24    /// Inject full soul.md + memory blocks + active guidance.
25    Full,
26    /// Disable all subconscious hooks.
27    Off,
28}
29
30impl SubconsciousMode {
31    /// Read mode from `NEXUS_SUBCONSCIOUS_MODE` env var.
32    pub fn from_env() -> Self {
33        match std::env::var("NEXUS_SUBCONSCIOUS_MODE")
34            .unwrap_or_default()
35            .to_lowercase()
36            .as_str()
37        {
38            "full" => SubconsciousMode::Full,
39            "off" => SubconsciousMode::Off,
40            _ => SubconsciousMode::Whisper,
41        }
42    }
43}
44
45/// Result of a memory retrieval operation.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct RetrievalResult {
48    /// Soul.md content (truncated to budget).
49    pub soul_content: Option<String>,
50    /// Top relevant memories by embedding similarity.
51    pub recalled: Vec<RecalledMemory>,
52    /// Hot cache entries that have changed since last sync.
53    pub active_guidance: Vec<String>,
54    /// Metadata line for the context header.
55    pub stats: RetrievalStats,
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize)]
59pub struct RecalledMemory {
60    pub content: String,
61    pub relevance: f32,
62    pub tier: ConfidenceTier,
63    pub source: String,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct RetrievalStats {
68    pub total_memories: usize,
69    pub hot_cache_entries: usize,
70    pub soul_md_exists: bool,
71    pub soul_md_age_minutes: Option<i64>,
72}
73
74/// The retrieval engine drives subconscious memory injection.
75pub struct RetrievalEngine {
76    project_root: PathBuf,
77    mode: SubconsciousMode,
78    /// Maximum tokens for the entire injection payload.
79    token_budget: usize,
80}
81
/// Per-section token allowances for one injection payload.
///
/// Values are in tokens; `format_for_stdout` turns them into byte limits by
/// multiplying by 4 and, for list sections, dividing evenly across entries.
struct TokenBudgets {
    /// Allowance for the `<nexus_soul>` section.
    soul: usize,
    /// Allowance shared by all `<memory>` entries in `<nexus_recall>`.
    recall: usize,
    /// Allowance shared by all lines in `<nexus_guidance>`.
    guidance: usize,
}
88
89impl RetrievalEngine {
90    /// Create a new retrieval engine for a project.
91    pub fn new(project_root: &Path, _config: Config) -> Self {
92        let mode = SubconsciousMode::from_env();
93        let token_budget = match mode {
94            SubconsciousMode::Whisper => 512,
95            SubconsciousMode::Full => 1024,
96            SubconsciousMode::Off => 0,
97        };
98        Self {
99            project_root: project_root.to_path_buf(),
100            mode,
101            token_budget,
102        }
103    }
104
105    /// Retrieve memories relevant to a prompt for UserPromptSubmit injection.
106    pub async fn retrieve_for_prompt(
107        &self,
108        prompt: &str,
109        sync_state: &SyncState,
110    ) -> RetrievalResult {
111        let stats = self.gather_stats();
112        let soul_content = self.load_soul_content();
113        let hot_cache = self.load_hot_cache();
114        let recalled = self.search_memories(prompt, &hot_cache).await;
115        let active_guidance = self.compute_guidance(&hot_cache, sync_state);
116
117        RetrievalResult {
118            soul_content,
119            recalled,
120            active_guidance,
121            stats,
122        }
123    }
124
125    /// Lightweight check for updates since last sync (PreToolUse hook).
126    pub fn check_for_updates(&self, sync_state: &mut SyncState) -> Option<String> {
127        let soul_content = self.load_soul_content();
128        let soul_hash = sync_state::soul_content_hash(soul_content.as_deref().unwrap_or(""));
129        let hot_cache = self.load_hot_cache();
130        let hot_cache_ids: Vec<String> = hot_cache
131            .hot_cache
132            .entries
133            .iter()
134            .map(|e| e.memory_id.to_string())
135            .collect();
136        let hot_cache_hash = sync_state::hot_cache_hash(&hot_cache_ids);
137
138        if !sync_state.has_updates(
139            &soul_hash,
140            hot_cache.hot_cache.entries.len(),
141            &hot_cache_hash,
142        ) {
143            return None;
144        }
145
146        // Build a lightweight delta injection
147        let mut parts = Vec::new();
148
149        if soul_hash != sync_state.last_soul_hash {
150            if let Some(soul) = soul_content {
151                let truncated = truncate_to_chars(&soul, 300);
152                parts.push(format!(
153                    "<soul_update>\n{}\n</soul_update>",
154                    escape_xml(&truncated)
155                ));
156            } else {
157                parts.push("<soul_update deleted=\"true\" />".to_string());
158            }
159        }
160
161        if hot_cache.hot_cache.entries.len() > sync_state.last_hot_cache_count {
162            // Cache grew: list the newly promoted entries
163            let new_count = hot_cache
164                .hot_cache
165                .entries
166                .len()
167                .saturating_sub(sync_state.last_hot_cache_count);
168            let new_entries: Vec<_> = hot_cache
169                .hot_cache
170                .entries
171                .iter()
172                .rev()
173                .take(new_count)
174                .map(|e| {
175                    let tier = match e.tier {
176                        ConfidenceTier::Loud => "LOUD",
177                        ConfidenceTier::Clear => "CLEAR",
178                        ConfidenceTier::Whisper => "WHISPER",
179                    };
180                    format!(
181                        "[{}] {}",
182                        tier,
183                        escape_xml(&truncate_to_chars(&e.content, 120))
184                    )
185                })
186                .collect();
187            parts.push(format!(
188                "<cache_promotions count=\"{new_count}\">\n{}\n</cache_promotions>",
189                new_entries.join("\n")
190            ));
191        } else if hot_cache_hash != sync_state.last_hot_cache_hash {
192            // Cache content changed without net growth (eviction at capacity,
193            // or count decreased). Emit a summary of the current state.
194            let entries: Vec<_> = hot_cache
195                .hot_cache
196                .entries
197                .iter()
198                .rev()
199                .take(5)
200                .map(|e| {
201                    let tier = match e.tier {
202                        ConfidenceTier::Loud => "LOUD",
203                        ConfidenceTier::Clear => "CLEAR",
204                        ConfidenceTier::Whisper => "WHISPER",
205                    };
206                    format!(
207                        "[{}] {}",
208                        tier,
209                        escape_xml(&truncate_to_chars(&e.content, 120))
210                    )
211                })
212                .collect();
213            parts.push(format!(
214                "<cache_update count=\"{}\">\n{}\n</cache_update>",
215                hot_cache.hot_cache.entries.len(),
216                entries.join("\n")
217            ));
218        }
219
220        if parts.is_empty() {
221            return None;
222        }
223
224        // Advance the sync state watermark to prevent re-emission
225        sync_state.advance(
226            soul_hash,
227            hot_cache.hot_cache.entries.len(),
228            hot_cache_hash,
229            None,
230        );
231
232        Some(format!(
233            "<nexus_delta>\n{}\n</nexus_delta>",
234            parts.join("\n")
235        ))
236    }
237
238    /// Format a retrieval result as XML for stdout injection.
239    pub fn format_for_stdout(&self, result: &RetrievalResult) -> String {
240        if self.mode == SubconsciousMode::Off {
241            return String::new();
242        }
243
244        let budgets = self.compute_budgets();
245        let mut sections = Vec::new();
246
247        // Context header
248        let memory_count = result.stats.total_memories;
249        let hot_count = result.stats.hot_cache_entries;
250        let soul_status = if result.stats.soul_md_exists {
251            "synthesized"
252        } else {
253            "not yet generated"
254        };
255
256        sections.push(format!(
257            "<nexus_context>\n\
258             Subconscious memory active. {memory_count} memories indexed, \
259             {hot_count} in hot cache, soul.md {soul_status}.\n\
260             </nexus_context>"
261        ));
262
263        // Soul section (full mode only, or whisper if explicitly requested)
264        if self.mode == SubconsciousMode::Full {
265            if let Some(ref soul) = result.soul_content {
266                let truncated = truncate_to_chars(soul, budgets.soul * 4);
267                sections.push(format!(
268                    "<nexus_soul>\n{}\n</nexus_soul>",
269                    escape_xml(&truncated)
270                ));
271            }
272        }
273
274        // Recall section
275        if !result.recalled.is_empty() {
276            let mut entries = Vec::new();
277            for mem in &result.recalled {
278                let tier = match mem.tier {
279                    ConfidenceTier::Loud => "LOUD",
280                    ConfidenceTier::Clear => "CLEAR",
281                    ConfidenceTier::Whisper => "WHISPER",
282                };
283                let truncated = truncate_to_chars(
284                    &mem.content,
285                    budgets.recall * 4 / result.recalled.len().max(1),
286                );
287                entries.push(format!(
288                    "<memory relevance=\"{:.2}\" tier=\"{tier}\" source=\"{}\">\n{}\n</memory>",
289                    mem.relevance,
290                    escape_xml(&mem.source),
291                    escape_xml(&truncated)
292                ));
293            }
294            sections.push(format!(
295                "<nexus_recall>\n{}\n</nexus_recall>",
296                entries.join("\n")
297            ));
298        }
299
300        // Active guidance section
301        if !result.active_guidance.is_empty() {
302            let truncated_guidance: Vec<_> = result
303                .active_guidance
304                .iter()
305                .map(|g| {
306                    escape_xml(&truncate_to_chars(
307                        g,
308                        budgets.guidance * 4 / result.active_guidance.len().max(1),
309                    ))
310                })
311                .collect();
312            sections.push(format!(
313                "<nexus_guidance>\n{}\n</nexus_guidance>",
314                truncated_guidance.join("\n")
315            ));
316        }
317
318        sections.join("\n\n")
319    }
320
321    /// Format the initial session-start injection.
322    pub fn format_session_start(
323        &self,
324        hot_cache: &CognitiveCache,
325        soul_content: Option<&str>,
326    ) -> String {
327        if self.mode == SubconsciousMode::Off {
328            return String::new();
329        }
330
331        let mut parts = Vec::new();
332
333        parts.push(format!(
334            "<nexus_context>\n\
335             Subconscious memory active. {} entries in hot cache.\n\
336             Soul.md {}.\n\
337             </nexus_context>",
338            hot_cache.hot_cache.entries.len(),
339            if soul_content.is_some() {
340                "loaded"
341            } else {
342                "not yet generated"
343            }
344        ));
345
346        if self.mode == SubconsciousMode::Full {
347            if let Some(soul) = soul_content {
348                let truncated = truncate_to_chars(soul, 400);
349                parts.push(format!(
350                    "<nexus_soul>\n{}\n</nexus_soul>",
351                    escape_xml(&truncated)
352                ));
353            }
354
355            // Show all hot cache entries in full mode
356            if !hot_cache.hot_cache.entries.is_empty() {
357                let entries: Vec<_> = hot_cache
358                    .hot_cache
359                    .entries
360                    .iter()
361                    .map(|e| {
362                        let tier = match e.tier {
363                            ConfidenceTier::Loud => "LOUD",
364                            ConfidenceTier::Clear => "CLEAR",
365                            ConfidenceTier::Whisper => "WHISPER",
366                        };
367                        format!(
368                            "[{tier}] {}",
369                            escape_xml(&truncate_to_chars(&e.content, 120))
370                        )
371                    })
372                    .collect();
373                parts.push(format!(
374                    "<nexus_hot_cache>\n{}\n</nexus_hot_cache>",
375                    entries.join("\n")
376                ));
377            }
378        } else {
379            // Whisper mode: show top 3 by relevance
380            let mut sorted = hot_cache.hot_cache.entries.clone();
381            sorted.sort_by(|a, b| {
382                b.relevance_score
383                    .partial_cmp(&a.relevance_score)
384                    .unwrap_or(std::cmp::Ordering::Equal)
385            });
386            let top: Vec<_> = sorted
387                .iter()
388                .take(3)
389                .map(|e| {
390                    let tier = match e.tier {
391                        ConfidenceTier::Loud => "LOUD",
392                        ConfidenceTier::Clear => "CLEAR",
393                        ConfidenceTier::Whisper => "WHISPER",
394                    };
395                    format!(
396                        "[{tier}] {}",
397                        escape_xml(&truncate_to_chars(&e.content, 80))
398                    )
399                })
400                .collect();
401            if !top.is_empty() {
402                parts.push(format!(
403                    "<nexus_whisper>\n{}\n</nexus_whisper>",
404                    top.join("\n")
405                ));
406            }
407        }
408
409        parts.join("\n\n")
410    }
411
412    // ── Internal helpers ──────────────────────────────────────────────
413
414    fn gather_stats(&self) -> RetrievalStats {
415        let hot_cache = self.load_hot_cache();
416        let soul_path = soul_path();
417
418        let (soul_md_exists, soul_md_age_minutes) = if soul_path.exists() {
419            let age = std::fs::metadata(&soul_path)
420                .ok()
421                .and_then(|m| m.modified().ok())
422                .map(|modified| {
423                    let modified: chrono::DateTime<chrono::Local> = modified.into();
424                    chrono::Utc::now()
425                        .signed_duration_since(modified.with_timezone(&chrono::Utc))
426                        .num_minutes()
427                });
428            (true, age)
429        } else {
430            (false, None)
431        };
432
433        // For total_memories, we'd need a DB query; approximate from hot+cold cache
434        let total = hot_cache.hot_cache.entries.len() + hot_cache.cold_index.entries.len();
435
436        RetrievalStats {
437            total_memories: total,
438            hot_cache_entries: hot_cache.hot_cache.entries.len(),
439            soul_md_exists,
440            soul_md_age_minutes,
441        }
442    }
443
444    pub fn load_soul_content(&self) -> Option<String> {
445        let path = soul_path();
446        if !path.exists() {
447            return None;
448        }
449        match std::fs::read_to_string(&path) {
450            Ok(content) => {
451                if content.trim().is_empty() {
452                    None
453                } else {
454                    Some(content)
455                }
456            }
457            Err(e) => {
458                debug!("Failed to read soul.md: {e}");
459                None
460            }
461        }
462    }
463
464    fn load_hot_cache(&self) -> CognitiveCache {
465        let nexus_dir = self.project_root.join(".nexus");
466        CognitiveCache::load_or_init(&nexus_dir)
467    }
468
469    /// Hot-cache-only search. Prompt-based semantic retrieval is wired in the CLI
470    /// layer (`subconscious.rs`) which has async DB access.
471    async fn search_memories(
472        &self,
473        _prompt: &str,
474        hot_cache: &CognitiveCache,
475    ) -> Vec<RecalledMemory> {
476        let mut entries: Vec<_> = hot_cache
477            .hot_cache
478            .entries
479            .iter()
480            .filter(|e| e.relevance_score >= 0.5)
481            .map(|e| RecalledMemory {
482                content: e.content.clone(),
483                relevance: e.relevance_score,
484                tier: e.tier,
485                source: "hot_cache".to_string(),
486            })
487            .collect();
488
489        entries.sort_by(|a, b| {
490            b.relevance
491                .partial_cmp(&a.relevance)
492                .unwrap_or(std::cmp::Ordering::Equal)
493        });
494        entries.truncate(5);
495        entries
496    }
497
498    fn compute_guidance(&self, hot_cache: &CognitiveCache, sync_state: &SyncState) -> Vec<String> {
499        // Find entries surfaced since last sync
500        let mut guidance = Vec::new();
501        for entry in &hot_cache.hot_cache.entries {
502            if entry.last_surfaced > sync_state.last_sync_timestamp {
503                let tier = match entry.tier {
504                    ConfidenceTier::Loud => "LOUD",
505                    ConfidenceTier::Clear => "CLEAR",
506                    ConfidenceTier::Whisper => "WHISPER",
507                };
508                guidance.push(format!(
509                    "[{tier}] {}",
510                    truncate_to_chars(&entry.content, 120)
511                ));
512            }
513        }
514        guidance.truncate(3);
515        guidance
516    }
517
518    fn compute_budgets(&self) -> TokenBudgets {
519        let total = self.token_budget;
520        match self.mode {
521            SubconsciousMode::Whisper => TokenBudgets {
522                soul: 0, // No soul in whisper mode
523                recall: total * 3 / 4,
524                guidance: total / 4,
525            },
526            SubconsciousMode::Full => TokenBudgets {
527                soul: total / 2,
528                recall: total / 3,
529                guidance: total / 6,
530            },
531            SubconsciousMode::Off => TokenBudgets {
532                soul: 0,
533                recall: 0,
534                guidance: 0,
535            },
536        }
537    }
538}
539
/// Truncate `text` to at most `max_chars` **bytes**, appending `…` when cut.
///
/// Despite the name, the limit is a UTF-8 byte count, and it includes the
/// 3-byte ellipsis. The cut never splits a multi-byte character.
///
/// When possible, the cut backs up to the last space so words stay whole, and
/// the whitespace before that space is dropped. If backing up would leave
/// nothing (e.g. the only space is a leading one), the hard character-boundary
/// cut is kept instead.
///
/// Returns an empty string when `max_chars` cannot even fit the ellipsis.
fn truncate_to_chars(text: &str, max_chars: usize) -> String {
    const ELLIPSIS: &str = "…";

    if text.len() <= max_chars {
        return text.to_string();
    }
    if max_chars <= ELLIPSIS.len() {
        return String::new();
    }

    // Largest char boundary that still leaves room for the ellipsis.
    let mut end = max_chars - ELLIPSIS.len();
    while end > 0 && !text.is_char_boundary(end) {
        end -= 1;
    }
    let hard_cut = &text[..end];

    // Prefer a word boundary, but never one that discards all content.
    let kept = hard_cut
        .rfind(' ')
        .map(|pos| hard_cut[..pos].trim_end())
        .filter(|prefix| !prefix.is_empty())
        .unwrap_or(hard_cut);

    let mut out = String::with_capacity(kept.len() + ELLIPSIS.len());
    out.push_str(kept);
    out.push_str(ELLIPSIS);
    out
}
564
/// Replace the five XML special characters (`&`, `<`, `>`, `"`, `'`) with
/// their entity references.
///
/// Every other character, including non-ASCII text, is copied unchanged, so
/// the output is safe inside element content and quoted attribute values.
fn escape_xml(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut escaped, ch| {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            other => escaped.push(other),
        }
        escaped
    })
}
580
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Default configuration; the engine does not read it today.
    fn test_config() -> Config {
        Config::default()
    }

    /// Engine over a fresh temp dir with its mode forced to `mode`.
    ///
    /// The `TempDir` is returned so the directory outlives the engine.
    fn engine_in_mode(mode: SubconsciousMode) -> (TempDir, RetrievalEngine) {
        let dir = TempDir::new().unwrap();
        let mut engine = RetrievalEngine::new(dir.path(), test_config());
        engine.mode = mode;
        (dir, engine)
    }

    /// Stats fixture: soul.md counts as existing exactly when an age is given.
    fn stats(total_memories: usize, hot_cache_entries: usize, soul_age: Option<i64>) -> RetrievalStats {
        RetrievalStats {
            total_memories,
            hot_cache_entries,
            soul_md_exists: soul_age.is_some(),
            soul_md_age_minutes: soul_age,
        }
    }

    #[test]
    fn mode_from_env_default_is_whisper() {
        // Clear the variable so the fallback path is exercised.
        std::env::remove_var("NEXUS_SUBCONSCIOUS_MODE");
        assert_eq!(SubconsciousMode::from_env(), SubconsciousMode::Whisper);
    }

    #[test]
    fn truncate_preserves_short_text() {
        let short = "hello";
        assert_eq!(truncate_to_chars(short, 10), short);
    }

    #[test]
    fn truncate_truncates_long_text() {
        let out = truncate_to_chars("hello world this is a test of truncation", 15);
        assert!(out.ends_with('…'));
        assert!(out.len() < 20);
    }

    #[test]
    fn truncate_handles_multibyte() {
        // Must not panic on a non-boundary byte index and must yield valid UTF-8.
        let out = truncate_to_chars("héllo wörld", 5);
        assert!(out.len() < 15);
        assert!(out.ends_with('…') || out == "héllo");
    }

    #[test]
    fn escape_xml_handles_special_chars() {
        let escaped = escape_xml("a<b>c&d\"e'f");
        assert_eq!(escaped, "a&lt;b&gt;c&amp;d&quot;e&apos;f");
    }

    #[test]
    fn format_off_mode_returns_empty() {
        let (_dir, engine) = engine_in_mode(SubconsciousMode::Off);
        let result = RetrievalResult {
            soul_content: Some("test".to_string()),
            recalled: Vec::new(),
            active_guidance: Vec::new(),
            stats: stats(10, 3, Some(5)),
        };
        assert!(engine.format_for_stdout(&result).is_empty());
    }

    #[test]
    fn format_whisper_mode_includes_context() {
        let (_dir, engine) = engine_in_mode(SubconsciousMode::Whisper);
        let memory = RecalledMemory {
            content: "test memory".to_string(),
            relevance: 0.9,
            tier: ConfidenceTier::Loud,
            source: "hot_cache".to_string(),
        };
        let result = RetrievalResult {
            soul_content: None,
            recalled: vec![memory],
            active_guidance: Vec::new(),
            stats: stats(10, 3, None),
        };

        let output = engine.format_for_stdout(&result);
        for tag in ["<nexus_context>", "<nexus_recall>"] {
            assert!(output.contains(tag));
        }
        assert!(!output.contains("<nexus_soul>"));
    }

    #[test]
    fn format_full_mode_includes_soul() {
        let (_dir, engine) = engine_in_mode(SubconsciousMode::Full);
        let result = RetrievalResult {
            soul_content: Some("I am a helpful assistant".to_string()),
            recalled: Vec::new(),
            active_guidance: Vec::new(),
            stats: stats(10, 0, Some(5)),
        };

        let output = engine.format_for_stdout(&result);
        for needle in ["<nexus_soul>", "I am a helpful assistant"] {
            assert!(output.contains(needle));
        }
    }

    #[test]
    fn check_for_updates_returns_none_when_unchanged() {
        let dir = TempDir::new().unwrap();
        let engine = RetrievalEngine::new(dir.path(), test_config());
        let mut state = SyncState::new("test");

        // Seed the watermark with exactly what the engine will observe.
        let soul = engine.load_soul_content();
        let cache = engine.load_hot_cache();
        let ids: Vec<String> = cache
            .hot_cache
            .entries
            .iter()
            .map(|e| e.memory_id.to_string())
            .collect();
        state.last_soul_hash = sync_state::soul_content_hash(soul.as_deref().unwrap_or(""));
        state.last_hot_cache_count = cache.hot_cache.entries.len();
        state.last_hot_cache_hash = sync_state::hot_cache_hash(&ids);

        assert!(engine.check_for_updates(&mut state).is_none());
    }

    #[test]
    fn session_start_format_contains_context() {
        let (_dir, engine) = engine_in_mode(SubconsciousMode::Whisper);
        let output = engine.format_session_start(&CognitiveCache::default(), None);
        assert!(output.contains("<nexus_context>"));
    }
}