Skip to main content

nexus_memory_hooks/
enrichment.rs

1//! LLM enrichment for hook-derived memory candidates
2//!
3//! Uses the nexus-llm crate to categorize, rewrite, and comment on
4//! memory candidates before persistence.
5
6use std::sync::Arc;
7
8use nexus_llm::{
9    create_client_auto_with_fallback, ChatMessage, GenerateParams, LlmClient, LlmClientJson,
10};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info};
13
14use crate::candidate::MemoryCandidate;
15use crate::claude_payload::NormalizedHookEvent;
16
17/// Result of LLM enrichment for a single memory candidate.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct EnrichedMemory {
20    /// Whether this memory should be stored
21    pub store: bool,
22    /// Nexus category for the memory
23    pub category: String,
24    /// Rewritten, standalone, retrieval-friendly content
25    #[serde(alias = "memory")]
26    pub memory_text: String,
27    /// Labels for retrieval
28    pub labels: Vec<String>,
29    /// Optional Memory Lane type
30    #[serde(rename = "memory_lane_type")]
31    pub memory_lane_type: Option<String>,
32    /// Model-authored comment explaining why this is worth keeping
33    pub comment: String,
34    /// Model confidence in the enrichment
35    pub confidence: f32,
36}
37
38/// Result of enriching a batch of candidates.
39#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct EnrichmentBatchResult {
41    #[serde(alias = "accepted_memories", default)]
42    pub memories: Vec<EnrichedMemory>,
43}
44
/// System prompt sent as the first message of every enrichment request.
///
/// It states the acceptance criteria, the allowed categories and Memory Lane
/// types, and the exact JSON response shape. The response shape mirrors the
/// fields of [`EnrichmentBatchResult`] and [`EnrichedMemory`], which the reply
/// is deserialized into; every non-optional field of those structs must be
/// present or the whole batch fails to parse, so keep the prompt and the
/// structs in sync when either changes.
const ENRICHMENT_SYSTEM_PROMPT: &str = r#"You are enriching agent hook events into durable memories for a retrieval system.

Decide whether each candidate is worth storing.

Only keep information that is durable, decision-relevant, preference-revealing, specification-bearing, contextual in a useful way, or session-significant.

Allowed categories:
- general
- facts
- preferences
- context
- specifications
- session

For each accepted memory:
- rewrite the memory into a standalone retrieval-friendly sentence or short paragraph
- assign exactly one allowed category
- produce 2-5 labels
- optionally assign a memory_lane_type from: correction, decision, commitment, insight, learning, confidence, pattern_seed, cross_agent, workflow_note, gap
- produce a comment explaining why the memory is worth keeping

The comment must be model-authored and should explain why the memory is worth keeping, what retrieval value it has, or how it should be interpreted later.

Reject low-signal operational noise.

Respond with a single JSON object of the form {"memories": [...]}. Rejected candidates may simply be omitted. Every entry in "memories" must contain all of these keys:
- "store": boolean, true when the memory should be persisted
- "category": string, exactly one of the allowed categories
- "memory_text": string, the rewritten memory
- "labels": array of strings
- "memory_lane_type": string from the list above, or null
- "comment": string
- "confidence": number between 0 and 1

If no candidate is worth storing, respond with {"memories": []}.

Return strict JSON only. No markdown fences."#;
72
73/// Enrichment service for memory candidates.
74pub struct EnrichmentService {
75    client: Arc<dyn LlmClient>,
76    model_name: String,
77}
78
79impl EnrichmentService {
80    /// Create a new enrichment service using the configured LLM client.
81    pub fn new() -> anyhow::Result<Self> {
82        let client = create_client_auto_with_fallback()?;
83        let model_name = client.model_name();
84
85        info!("EnrichmentService initialized with model: {}", model_name);
86
87        Ok(Self { client, model_name })
88    }
89
90    /// Get the model name being used for enrichment.
91    pub fn model_name(&self) -> &str {
92        &self.model_name
93    }
94
95    /// Enrich a batch of memory candidates using the LLM.
96    ///
97    /// Sends all candidates to the LLM in a single call for efficiency,
98    /// receiving structured enrichment decisions back.
99    pub async fn enrich_candidates(
100        &self,
101        candidates: &[MemoryCandidate],
102        event: &NormalizedHookEvent,
103    ) -> anyhow::Result<EnrichmentBatchResult> {
104        if candidates.is_empty() {
105            debug!("No candidates to enrich");
106            return Ok(EnrichmentBatchResult {
107                memories: Vec::new(),
108            });
109        }
110
111        let user_payload = self.build_user_payload(candidates, event);
112        let user_message = serde_json::to_string_pretty(&user_payload)?;
113
114        debug!(
115            "Enriching {} candidates with event: {}",
116            candidates.len(),
117            event.event_name
118        );
119
120        let params = GenerateParams {
121            messages: vec![
122                ChatMessage::system(ENRICHMENT_SYSTEM_PROMPT),
123                ChatMessage::user(user_message),
124            ],
125            max_tokens: 4096,
126            temperature: 0.3,
127            json_mode: true,
128        };
129
130        let result: EnrichmentBatchResult = self.client.generate_json(params).await?;
131
132        info!(
133            "Enrichment complete: {} memories returned",
134            result.memories.len()
135        );
136
137        Ok(result)
138    }
139
140    /// Build the user payload for the enrichment LLM call.
141    fn build_user_payload(
142        &self,
143        candidates: &[MemoryCandidate],
144        event: &NormalizedHookEvent,
145    ) -> serde_json::Value {
146        // Build event context
147        let event_obj = serde_json::json!({
148            "agent": event.agent,
149            "event_name": event.event_name,
150            "tool_name": event.tool_name,
151            "session_id": event.session_id,
152            "turn_id": event.turn_id,
153        });
154
155        // Build candidates array
156        let candidates_array: Vec<serde_json::Value> = candidates
157            .iter()
158            .map(|c| {
159                serde_json::json!({
160                    "candidate_id": c.candidate_id,
161                    "signal_score": c.signal_score,
162                    "provisional_category": c.provisional_category,
163                    "memory_text": c.memory_text,
164                    "evidence": c.evidence,
165                    "labels": c.labels,
166                })
167            })
168            .collect();
169
170        serde_json::json!({
171            "event": event_obj,
172            "candidates": candidates_array,
173        })
174    }
175}
176
#[cfg(test)]
mod tests {
    //! Unit tests for the enrichment data types and request payload builder.
    //! None of these tests contact a real LLM.

    use super::*;
    use chrono::Utc;
    use serde_json::json;

    /// An `EnrichedMemory` survives a JSON round trip.
    #[test]
    fn test_enriched_memory_serialization() {
        let memory = EnrichedMemory {
            store: true,
            category: "preferences".to_string(),
            memory_text: "User prefers Rust over C++ for systems programming".to_string(),
            labels: vec![
                "rust".to_string(),
                "cpp".to_string(),
                "preferences".to_string(),
            ],
            // Must be one of the lane types listed in the system prompt.
            memory_lane_type: Some("insight".to_string()),
            comment: "Clear preference statement affecting future tool selection".to_string(),
            confidence: 0.9,
        };

        let serialized = serde_json::to_string(&memory).unwrap();
        let deserialized: EnrichedMemory = serde_json::from_str(&serialized).unwrap();

        assert!(deserialized.store);
        assert_eq!(deserialized.category, "preferences");
        assert_eq!(deserialized.labels.len(), 3);
        assert_eq!(deserialized.memory_lane_type.as_deref(), Some("insight"));
    }

    /// The `memory` key is accepted in place of `memory_text`, and an omitted
    /// `memory_lane_type` deserializes to `None`.
    #[test]
    fn test_enriched_memory_accepts_memory_alias() {
        let payload = json!({
            "store": true,
            "category": "facts",
            "memory": "Project uses SQLite for persistence",
            "labels": ["sqlite", "database"],
            "comment": "Architecture fact",
            "confidence": 0.95
        });

        let deserialized: EnrichedMemory =
            serde_json::from_value(payload).expect("deserialize `memory` alias payload");

        assert_eq!(
            deserialized.memory_text,
            "Project uses SQLite for persistence"
        );
        assert!(deserialized.memory_lane_type.is_none());
    }

    /// An `EnrichmentBatchResult` survives a JSON round trip.
    #[test]
    fn test_enrichment_batch_result_serialization() {
        let result = EnrichmentBatchResult {
            memories: vec![EnrichedMemory {
                store: true,
                category: "facts".to_string(),
                memory_text: "Project uses SQLite for persistence".to_string(),
                labels: vec!["sqlite".to_string(), "database".to_string()],
                memory_lane_type: None,
                comment: "Architecture fact".to_string(),
                confidence: 0.95,
            }],
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: EnrichmentBatchResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
    }

    /// The `accepted_memories` key is accepted in place of `memories`.
    #[test]
    fn test_enrichment_batch_result_accepts_accepted_memories_alias() {
        let payload = json!({
            "accepted_memories": [
                {
                    "store": true,
                    "category": "facts",
                    "memory_text": "Project uses SQLite for persistence",
                    "labels": ["sqlite", "database"],
                    "memory_lane_type": null,
                    "comment": "Architecture fact",
                    "confidence": 0.95
                }
            ]
        });

        let deserialized: EnrichmentBatchResult =
            serde_json::from_value(payload).expect("deserialize alias payload");

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
    }

    /// The request payload carries the event summary and one entry per candidate.
    #[test]
    fn test_build_user_payload() {
        let service = EnrichmentService {
            client: Arc::new(MockLlmClient::new()),
            model_name: "test-model".to_string(),
        };

        let candidates = vec![MemoryCandidate {
            candidate_id: "test-1".to_string(),
            source_event_name: "test_event".to_string(),
            source_agent: "test-agent".to_string(),
            signal_score: 0.8,
            provisional_category: Some("preferences".to_string()),
            memory_text: "Test memory".to_string(),
            evidence: json!({"key": "value"}),
            labels: vec!["test".to_string()],
        }];

        let event = NormalizedHookEvent {
            agent: "claude-code".to_string(),
            event_name: "test_event".to_string(),
            observed_at: Utc::now(),
            session_id: Some("session-123".to_string()),
            turn_id: Some("turn-456".to_string()),
            cwd: Some("/home/user/project".to_string()),
            tool_name: Some("test_tool".to_string()),
            tool_input: None,
            tool_response_text: None,
            assistant_message_text: None,
            user_message_text: None,
            raw_payload: json!({}),
        };

        let payload = service.build_user_payload(&candidates, &event);

        assert_eq!(payload["event"]["agent"], "claude-code");
        assert_eq!(payload["event"]["event_name"], "test_event");
        assert_eq!(payload["candidates"].as_array().unwrap().len(), 1);
        assert_eq!(payload["candidates"][0]["candidate_id"], "test-1");
    }

    /// Stand-in LLM client; it exists only so an `EnrichmentService` can be
    /// constructed in tests without a real backend.
    struct MockLlmClient;

    impl MockLlmClient {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockLlmClient {
        /// Always answers with an empty JSON object.
        async fn generate(
            &self,
            _params: GenerateParams,
        ) -> nexus_llm::Result<nexus_llm::GenerateResponse> {
            Ok(nexus_llm::GenerateResponse {
                content: "{}".to_string(),
                model: "mock-model".to_string(),
                usage: None,
            })
        }

        fn provider_name(&self) -> String {
            "mock".to_string()
        }

        fn model_name(&self) -> String {
            "mock-model".to_string()
        }
    }
}