Skip to main content

nexus_memory_hooks/
enrichment.rs

1//! LLM enrichment for hook-derived memory candidates
2//!
3//! Uses the nexus-llm crate to categorize, rewrite, and comment on
4//! memory candidates before persistence.
5
6use std::sync::Arc;
7
8use nexus_llm::{
9    create_client_auto_with_fallback, ChatMessage, GenerateParams, LlmClient, LlmClientJson,
10};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info};
13
14use crate::candidate::MemoryCandidate;
15use crate::claude_payload::NormalizedHookEvent;
16
17/// Result of LLM enrichment for a single memory candidate.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct EnrichedMemory {
20    /// Whether this memory should be stored
21    pub store: bool,
22    /// Nexus category for the memory
23    pub category: String,
24    /// Rewritten, standalone, retrieval-friendly content
25    #[serde(alias = "memory")]
26    pub memory_text: String,
27    /// Labels for retrieval
28    pub labels: Vec<String>,
29    /// Optional Memory Lane type
30    #[serde(rename = "memory_lane_type")]
31    pub memory_lane_type: Option<String>,
32    /// Model-authored comment explaining why this is worth keeping
33    pub comment: String,
34    /// Model confidence in the enrichment
35    pub confidence: f32,
36}
37
38/// Result of enriching a batch of candidates.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct EnrichmentBatchResult {
41    #[serde(alias = "accepted_memories", default)]
42    pub memories: Vec<EnrichedMemory>,
43}
44
/// System prompt for enrichment LLM.
///
/// Besides the storage policy, the prompt spells out the exact JSON shape the
/// response is parsed into (`memories` array with `store`, `category`,
/// `memory_text`, `labels`, `memory_lane_type`, `comment`, `confidence`), so
/// the model does not have to guess key names.
const ENRICHMENT_SYSTEM_PROMPT: &str = r#"You are enriching agent hook events into durable memories for a retrieval system.

Decide whether each candidate is worth storing.

Keep information that is durable, decision-relevant, preference-revealing, specification-bearing, contextual in a useful way, session-significant, or contains learned patterns.

Be permissive: when in doubt, store it. The retrieval system benefits from having more context available.

Allowed categories:
- general
- facts
- preferences
- context
- specifications
- session

For each accepted memory:
- rewrite the memory into a standalone retrieval-friendly sentence or short paragraph
- assign exactly one allowed category
- produce 2-5 labels
- optionally assign a memory_lane_type from: correction, decision, commitment, insight, learning, confidence, pattern_seed, cross_agent, workflow_note, gap
- produce a comment explaining why the memory is worth keeping
- set store to true

Only reject candidates that are truly meaningless noise (e.g., "ls" output with no interesting files, empty responses).

Output format: a single JSON object with exactly one key, "memories", whose value is an array. Use an empty array when nothing is worth storing. Leave rejected candidates out of the array, or include them with store set to false. Every element of the array must contain all of these keys:
- "store": boolean
- "category": string, one of the allowed categories
- "memory_text": string, the rewritten memory
- "labels": array of strings
- "memory_lane_type": string from the list above, or null
- "comment": string
- "confidence": number between 0.0 and 1.0

Return strict JSON only. No markdown fences."#;
73
74/// Enrichment service for memory candidates.
75pub struct EnrichmentService {
76    client: Arc<dyn LlmClient>,
77    model_name: String,
78}
79
80impl EnrichmentService {
81    /// Create a new enrichment service using the configured LLM client.
82    pub fn new() -> anyhow::Result<Self> {
83        let client = create_client_auto_with_fallback()?;
84        let model_name = client.model_name();
85
86        info!("EnrichmentService initialized with model: {}", model_name);
87
88        Ok(Self { client, model_name })
89    }
90
91    /// Get the model name being used for enrichment.
92    pub fn model_name(&self) -> &str {
93        &self.model_name
94    }
95
96    /// Enrich a batch of memory candidates using the LLM.
97    ///
98    /// Sends all candidates to the LLM in a single call for efficiency,
99    /// receiving structured enrichment decisions back.
100    pub async fn enrich_candidates(
101        &self,
102        candidates: &[MemoryCandidate],
103        event: &NormalizedHookEvent,
104    ) -> anyhow::Result<EnrichmentBatchResult> {
105        if candidates.is_empty() {
106            debug!("No candidates to enrich");
107            return Ok(EnrichmentBatchResult {
108                memories: Vec::new(),
109            });
110        }
111
112        let user_payload = self.build_user_payload(candidates, event);
113        let user_message = serde_json::to_string_pretty(&user_payload)?;
114
115        debug!(
116            "Enriching {} candidates with event: {}",
117            candidates.len(),
118            event.event_name
119        );
120
121        let params = GenerateParams {
122            messages: vec![
123                ChatMessage::system(ENRICHMENT_SYSTEM_PROMPT),
124                ChatMessage::user(user_message),
125            ],
126            max_tokens: 4096,
127            temperature: 0.3,
128            json_mode: true,
129        };
130
131        let result: EnrichmentBatchResult = self.client.generate_json(params).await?;
132
133        info!(
134            "Enrichment complete: {} memories returned",
135            result.memories.len()
136        );
137
138        Ok(result)
139    }
140
141    /// Build the user payload for the enrichment LLM call.
142    fn build_user_payload(
143        &self,
144        candidates: &[MemoryCandidate],
145        event: &NormalizedHookEvent,
146    ) -> serde_json::Value {
147        // Build event context
148        let event_obj = serde_json::json!({
149            "agent": event.agent,
150            "event_name": event.event_name,
151            "tool_name": event.tool_name,
152            "session_id": event.session_id,
153            "turn_id": event.turn_id,
154        });
155
156        // Build candidates array
157        let candidates_array: Vec<serde_json::Value> = candidates
158            .iter()
159            .map(|c| {
160                serde_json::json!({
161                    "candidate_id": c.candidate_id,
162                    "signal_score": c.signal_score,
163                    "provisional_category": c.provisional_category,
164                    "memory_text": c.memory_text,
165                    "evidence": c.evidence,
166                    "labels": c.labels,
167                })
168            })
169            .collect();
170
171        serde_json::json!({
172            "event": event_obj,
173            "candidates": candidates_array,
174        })
175    }
176}
177
/// Unit tests for the enrichment data types and request payload builder.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use serde_json::json;

    /// A fully populated memory survives a JSON round trip.
    #[test]
    fn test_enriched_memory_serialization() {
        let memory = EnrichedMemory {
            store: true,
            category: "preferences".to_string(),
            memory_text: "User prefers Rust over C++ for systems programming".to_string(),
            labels: vec![
                "rust".to_string(),
                "cpp".to_string(),
                "preferences".to_string(),
            ],
            // Use a lane type that the system prompt actually allows.
            memory_lane_type: Some("insight".to_string()),
            comment: "Clear preference statement affecting future tool selection".to_string(),
            confidence: 0.9,
        };

        let serialized = serde_json::to_string(&memory).unwrap();
        let deserialized: EnrichedMemory = serde_json::from_str(&serialized).unwrap();

        assert!(deserialized.store);
        assert_eq!(deserialized.category, "preferences");
        assert_eq!(deserialized.labels.len(), 3);
        assert_eq!(deserialized.memory_lane_type.as_deref(), Some("insight"));
    }

    /// The `memory` key is accepted as an alias for `memory_text`.
    #[test]
    fn test_enriched_memory_accepts_memory_alias() {
        let payload = json!({
            "store": true,
            "category": "facts",
            "memory": "Project uses SQLite for persistence",
            "labels": ["sqlite", "database"],
            "memory_lane_type": null,
            "comment": "Architecture fact",
            "confidence": 0.95
        });

        let deserialized: EnrichedMemory =
            serde_json::from_value(payload).expect("deserialize memory alias payload");

        assert_eq!(
            deserialized.memory_text,
            "Project uses SQLite for persistence"
        );
        assert!(deserialized.memory_lane_type.is_none());
    }

    /// A batch result survives a JSON round trip.
    #[test]
    fn test_enrichment_batch_result_serialization() {
        let result = EnrichmentBatchResult {
            memories: vec![EnrichedMemory {
                store: true,
                category: "facts".to_string(),
                memory_text: "Project uses SQLite for persistence".to_string(),
                labels: vec!["sqlite".to_string(), "database".to_string()],
                memory_lane_type: None,
                comment: "Architecture fact".to_string(),
                confidence: 0.95,
            }],
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: EnrichmentBatchResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
    }

    /// The `accepted_memories` key is accepted as an alias for `memories`.
    #[test]
    fn test_enrichment_batch_result_accepts_accepted_memories_alias() {
        let payload = json!({
            "accepted_memories": [
                {
                    "store": true,
                    "category": "facts",
                    "memory_text": "Project uses SQLite for persistence",
                    "labels": ["sqlite", "database"],
                    "memory_lane_type": null,
                    "comment": "Architecture fact",
                    "confidence": 0.95
                }
            ]
        });

        let deserialized: EnrichmentBatchResult =
            serde_json::from_value(payload).expect("deserialize alias payload");

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
    }

    /// An empty JSON object (what the mock client returns) yields an empty batch.
    #[test]
    fn test_enrichment_batch_result_defaults_to_empty() {
        let deserialized: EnrichmentBatchResult =
            serde_json::from_value(json!({})).expect("deserialize empty object");

        assert!(deserialized.memories.is_empty());
    }

    /// The request payload carries the event context and every candidate.
    #[test]
    fn test_build_user_payload() {
        let service = EnrichmentService {
            client: Arc::new(MockLlmClient::new()),
            model_name: "test-model".to_string(),
        };

        let candidates = vec![MemoryCandidate {
            candidate_id: "test-1".to_string(),
            source_event_name: "test_event".to_string(),
            source_agent: "test-agent".to_string(),
            signal_score: 0.8,
            provisional_category: Some("preferences".to_string()),
            memory_text: "Test memory".to_string(),
            evidence: json!({"key": "value"}),
            labels: vec!["test".to_string()],
        }];

        let event = NormalizedHookEvent {
            agent: "claude-code".to_string(),
            event_name: "test_event".to_string(),
            observed_at: Utc::now(),
            session_id: Some("session-123".to_string()),
            turn_id: Some("turn-456".to_string()),
            cwd: Some("/home/user/project".to_string()),
            tool_name: Some("test_tool".to_string()),
            tool_input: None,
            tool_response_text: None,
            assistant_message_text: None,
            user_message_text: None,
            observer: Some("claude-code".to_string()),
            subject: Some("claude-code".to_string()),
            session_key: Some("session-123".to_string()),
            raw_payload: json!({}),
        };

        let payload = service.build_user_payload(&candidates, &event);

        assert_eq!(payload["event"]["agent"], "claude-code");
        assert_eq!(payload["event"]["event_name"], "test_event");
        assert_eq!(payload["candidates"].as_array().unwrap().len(), 1);
        assert_eq!(payload["candidates"][0]["candidate_id"], "test-1");
    }

    // Mock LLM client for testing
    struct MockLlmClient;

    impl MockLlmClient {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockLlmClient {
        // Always answers with an empty JSON object.
        async fn generate(
            &self,
            _params: GenerateParams,
        ) -> nexus_llm::Result<nexus_llm::GenerateResponse> {
            Ok(nexus_llm::GenerateResponse {
                content: "{}".to_string(),
                model: "mock-model".to_string(),
                usage: None,
            })
        }

        fn provider_name(&self) -> String {
            "mock".to_string()
        }

        fn model_name(&self) -> String {
            "mock-model".to_string()
        }
    }
}