Skip to main content

nexus_memory_hooks/
enrichment.rs

//! LLM enrichment for hook-derived memory candidates.
//!
//! Uses the nexus-llm crate to categorize, rewrite, and comment on
//! memory candidates before persistence.
5
6use std::sync::Arc;
7
8use nexus_llm::{
9    create_client_auto_with_fallback, ChatMessage, GenerateParams, LlmClient, LlmClientJson,
10};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info};
13
14use crate::candidate::MemoryCandidate;
15use crate::claude_payload::NormalizedHookEvent;
16
17/// Result of LLM enrichment for a single memory candidate.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct EnrichedMemory {
20    /// Whether this memory should be stored
21    pub store: bool,
22    /// Nexus category for the memory
23    pub category: String,
24    /// Rewritten, standalone, retrieval-friendly content
25    pub memory_text: String,
26    /// Labels for retrieval
27    pub labels: Vec<String>,
28    /// Optional Memory Lane type
29    #[serde(rename = "memory_lane_type")]
30    pub memory_lane_type: Option<String>,
31    /// Model-authored comment explaining why this is worth keeping
32    pub comment: String,
33    /// Model confidence in the enrichment
34    pub confidence: f32,
35}
36
37/// Result of enriching a batch of candidates.
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct EnrichmentBatchResult {
40    pub memories: Vec<EnrichedMemory>,
41}
42
/// System prompt for the enrichment LLM.
///
/// Besides the selection and rewriting rules, the prompt spells out the exact
/// JSON shape of the reply. The key and field names listed there must stay in
/// sync with `EnrichmentBatchResult` and `EnrichedMemory`, because the model's
/// reply is deserialized directly into those types.
const ENRICHMENT_SYSTEM_PROMPT: &str = r#"You are enriching agent hook events into durable memories for a retrieval system.

Decide whether each candidate is worth storing.

Only keep information that is durable, decision-relevant, preference-revealing, specification-bearing, contextual in a useful way, or session-significant.

Allowed categories:
- general
- facts
- preferences
- context
- specifications
- session

For each accepted memory:
- rewrite the memory into a standalone retrieval-friendly sentence or short paragraph
- assign exactly one allowed category
- produce 2-5 labels
- optionally assign a memory_lane_type from: correction, decision, commitment, insight, learning, confidence, pattern_seed, cross_agent, workflow_note, gap
- produce a comment explaining why the memory is worth keeping

The comment must be model-authored and should explain why the memory is worth keeping, what retrieval value it has, or how it should be interpreted later.

Reject low-signal operational noise.

Output format: a single JSON object with exactly one key, "memories", whose value is an array with one object per candidate, in the same order as the input candidates. Each object has these fields:
- "store": boolean, true to keep the candidate and false to reject it
- "category": string, one of the allowed categories
- "memory_text": string, the rewritten memory
- "labels": array of strings
- "memory_lane_type": string from the list above, or null
- "comment": string
- "confidence": number from 0.0 to 1.0

Include every field for every candidate, including rejected ones.

Return strict JSON only. No markdown fences."#;
70
71/// Enrichment service for memory candidates.
72pub struct EnrichmentService {
73    client: Arc<dyn LlmClient>,
74    model_name: String,
75}
76
77impl EnrichmentService {
78    /// Create a new enrichment service using the configured LLM client.
79    pub fn new() -> anyhow::Result<Self> {
80        let client = create_client_auto_with_fallback()?;
81        let model_name = client.model_name();
82
83        info!("EnrichmentService initialized with model: {}", model_name);
84
85        Ok(Self { client, model_name })
86    }
87
88    /// Get the model name being used for enrichment.
89    pub fn model_name(&self) -> &str {
90        &self.model_name
91    }
92
93    /// Enrich a batch of memory candidates using the LLM.
94    ///
95    /// Sends all candidates to the LLM in a single call for efficiency,
96    /// receiving structured enrichment decisions back.
97    pub async fn enrich_candidates(
98        &self,
99        candidates: &[MemoryCandidate],
100        event: &NormalizedHookEvent,
101    ) -> anyhow::Result<EnrichmentBatchResult> {
102        if candidates.is_empty() {
103            debug!("No candidates to enrich");
104            return Ok(EnrichmentBatchResult {
105                memories: Vec::new(),
106            });
107        }
108
109        let user_payload = self.build_user_payload(candidates, event);
110        let user_message = serde_json::to_string_pretty(&user_payload)?;
111
112        debug!(
113            "Enriching {} candidates with event: {}",
114            candidates.len(),
115            event.event_name
116        );
117
118        let params = GenerateParams {
119            messages: vec![
120                ChatMessage::system(ENRICHMENT_SYSTEM_PROMPT),
121                ChatMessage::user(user_message),
122            ],
123            max_tokens: 4096,
124            temperature: 0.3,
125            json_mode: true,
126        };
127
128        let result: EnrichmentBatchResult = self.client.generate_json(params).await?;
129
130        info!(
131            "Enrichment complete: {} memories returned",
132            result.memories.len()
133        );
134
135        Ok(result)
136    }
137
138    /// Build the user payload for the enrichment LLM call.
139    fn build_user_payload(
140        &self,
141        candidates: &[MemoryCandidate],
142        event: &NormalizedHookEvent,
143    ) -> serde_json::Value {
144        // Build event context
145        let event_obj = serde_json::json!({
146            "agent": event.agent,
147            "event_name": event.event_name,
148            "tool_name": event.tool_name,
149            "session_id": event.session_id,
150            "turn_id": event.turn_id,
151        });
152
153        // Build candidates array
154        let candidates_array: Vec<serde_json::Value> = candidates
155            .iter()
156            .map(|c| {
157                serde_json::json!({
158                    "candidate_id": c.candidate_id,
159                    "signal_score": c.signal_score,
160                    "provisional_category": c.provisional_category,
161                    "memory_text": c.memory_text,
162                    "evidence": c.evidence,
163                    "labels": c.labels,
164                })
165            })
166            .collect();
167
168        serde_json::json!({
169            "event": event_obj,
170            "candidates": candidates_array,
171        })
172    }
173}
174
/// Unit tests for the enrichment types and the user-payload builder.
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use serde_json::json;

    /// A fully populated `EnrichedMemory` survives a JSON round trip.
    #[test]
    fn test_enriched_memory_serialization() {
        let memory = EnrichedMemory {
            store: true,
            category: "preferences".to_string(),
            memory_text: "User prefers Rust over C++ for systems programming".to_string(),
            labels: vec![
                "rust".to_string(),
                "cpp".to_string(),
                "preferences".to_string(),
            ],
            // Must be one of the lane types listed in ENRICHMENT_SYSTEM_PROMPT.
            memory_lane_type: Some("decision".to_string()),
            comment: "Clear preference statement affecting future tool selection".to_string(),
            confidence: 0.9,
        };

        let serialized = serde_json::to_string(&memory).unwrap();

        // The wire format uses the snake_case key the prompt refers to.
        let wire: serde_json::Value = serde_json::from_str(&serialized).unwrap();
        assert_eq!(wire["memory_lane_type"], "decision");

        let deserialized: EnrichedMemory = serde_json::from_str(&serialized).unwrap();

        assert!(deserialized.store);
        assert_eq!(deserialized.category, "preferences");
        assert_eq!(deserialized.memory_text, memory.memory_text);
        assert_eq!(deserialized.labels, memory.labels);
        assert_eq!(deserialized.labels.len(), 3);
        assert_eq!(deserialized.memory_lane_type.as_deref(), Some("decision"));
        assert_eq!(deserialized.comment, memory.comment);
        assert!((deserialized.confidence - 0.9).abs() < 1e-6);
    }

    /// A batch result survives a JSON round trip, including a `None` lane type.
    #[test]
    fn test_enrichment_batch_result_serialization() {
        let result = EnrichmentBatchResult {
            memories: vec![EnrichedMemory {
                store: true,
                category: "facts".to_string(),
                memory_text: "Project uses SQLite for persistence".to_string(),
                labels: vec!["sqlite".to_string(), "database".to_string()],
                memory_lane_type: None,
                comment: "Architecture fact".to_string(),
                confidence: 0.95,
            }],
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: EnrichmentBatchResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
        assert_eq!(deserialized.memories[0].labels.len(), 2);
        assert!(deserialized.memories[0].memory_lane_type.is_none());
    }

    /// The user payload carries the event summary and every candidate field.
    #[test]
    fn test_build_user_payload() {
        let service = EnrichmentService {
            client: Arc::new(MockLlmClient::new()),
            model_name: "test-model".to_string(),
        };

        let candidates = vec![MemoryCandidate {
            candidate_id: "test-1".to_string(),
            source_event_name: "test_event".to_string(),
            source_agent: "test-agent".to_string(),
            signal_score: 0.8,
            provisional_category: Some("preferences".to_string()),
            memory_text: "Test memory".to_string(),
            evidence: json!({"key": "value"}),
            labels: vec!["test".to_string()],
        }];

        let event = NormalizedHookEvent {
            agent: "claude-code".to_string(),
            event_name: "test_event".to_string(),
            observed_at: Utc::now(),
            session_id: Some("session-123".to_string()),
            turn_id: Some("turn-456".to_string()),
            cwd: Some("/home/user/project".to_string()),
            tool_name: Some("test_tool".to_string()),
            tool_input: None,
            tool_response_text: None,
            assistant_message_text: None,
            user_message_text: None,
            raw_payload: json!({}),
        };

        let payload = service.build_user_payload(&candidates, &event);

        // Event summary.
        assert_eq!(payload["event"]["agent"], "claude-code");
        assert_eq!(payload["event"]["event_name"], "test_event");
        assert_eq!(payload["event"]["tool_name"], "test_tool");
        assert_eq!(payload["event"]["session_id"], "session-123");
        assert_eq!(payload["event"]["turn_id"], "turn-456");

        // Candidate entries.
        assert_eq!(payload["candidates"].as_array().unwrap().len(), 1);
        assert_eq!(payload["candidates"][0]["candidate_id"], "test-1");
        assert_eq!(payload["candidates"][0]["provisional_category"], "preferences");
        assert_eq!(payload["candidates"][0]["memory_text"], "Test memory");
        assert_eq!(payload["candidates"][0]["evidence"], json!({"key": "value"}));
        assert_eq!(payload["candidates"][0]["labels"], json!(["test"]));
    }

    /// Mock LLM client for testing; it never contacts a real backend.
    struct MockLlmClient;

    impl MockLlmClient {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockLlmClient {
        async fn generate(
            &self,
            _params: GenerateParams,
        ) -> nexus_llm::Result<nexus_llm::GenerateResponse> {
            Ok(nexus_llm::GenerateResponse {
                // An empty but well-formed batch; a bare `{}` would lack the
                // `memories` key that `EnrichmentBatchResult` requires.
                content: r#"{"memories":[]}"#.to_string(),
                model: "mock-model".to_string(),
                usage: None,
            })
        }

        fn provider_name(&self) -> String {
            "mock".to_string()
        }

        fn model_name(&self) -> String {
            "mock-model".to_string()
        }
    }
}