1use std::sync::Arc;
7
8use nexus_llm::{
9 create_client_auto_with_fallback, ChatMessage, GenerateParams, LlmClient, LlmClientJson,
10};
11use serde::{Deserialize, Serialize};
12use tracing::{debug, info};
13
14use crate::candidate::MemoryCandidate;
15use crate::claude_payload::NormalizedHookEvent;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct EnrichedMemory {
20 pub store: bool,
22 pub category: String,
24 #[serde(alias = "memory")]
26 pub memory_text: String,
27 pub labels: Vec<String>,
29 #[serde(rename = "memory_lane_type")]
31 pub memory_lane_type: Option<String>,
32 pub comment: String,
34 pub confidence: f32,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
40pub struct EnrichmentBatchResult {
41 #[serde(alias = "accepted_memories", default)]
42 pub memories: Vec<EnrichedMemory>,
43}
44
/// System prompt sent with every enrichment request.
///
/// It tells the model how to judge candidates, which categories and
/// memory-lane types are allowed, and that the reply must be bare JSON.
/// The text is part of the runtime contract with the model: edit it with
/// the same care as code.
const ENRICHMENT_SYSTEM_PROMPT: &str = r#"You are enriching agent hook events into durable memories for a retrieval system.

Decide whether each candidate is worth storing.

Only keep information that is durable, decision-relevant, preference-revealing, specification-bearing, contextual in a useful way, or session-significant.

Allowed categories:
- general
- facts
- preferences
- context
- specifications
- session

For each accepted memory:
- rewrite the memory into a standalone retrieval-friendly sentence or short paragraph
- assign exactly one allowed category
- produce 2-5 labels
- optionally assign a memory_lane_type from: correction, decision, commitment, insight, learning, confidence, pattern_seed, cross_agent, workflow_note, gap
- produce a comment explaining why the memory is worth keeping

The comment must be model-authored and should explain why the memory is worth keeping, what retrieval value it has, or how it should be interpreted later.

Reject low-signal operational noise.

Return strict JSON only. No markdown fences."#;
72
73pub struct EnrichmentService {
75 client: Arc<dyn LlmClient>,
76 model_name: String,
77}
78
79impl EnrichmentService {
80 pub fn new() -> anyhow::Result<Self> {
82 let client = create_client_auto_with_fallback()?;
83 let model_name = client.model_name();
84
85 info!("EnrichmentService initialized with model: {}", model_name);
86
87 Ok(Self { client, model_name })
88 }
89
90 pub fn model_name(&self) -> &str {
92 &self.model_name
93 }
94
95 pub async fn enrich_candidates(
100 &self,
101 candidates: &[MemoryCandidate],
102 event: &NormalizedHookEvent,
103 ) -> anyhow::Result<EnrichmentBatchResult> {
104 if candidates.is_empty() {
105 debug!("No candidates to enrich");
106 return Ok(EnrichmentBatchResult {
107 memories: Vec::new(),
108 });
109 }
110
111 let user_payload = self.build_user_payload(candidates, event);
112 let user_message = serde_json::to_string_pretty(&user_payload)?;
113
114 debug!(
115 "Enriching {} candidates with event: {}",
116 candidates.len(),
117 event.event_name
118 );
119
120 let params = GenerateParams {
121 messages: vec![
122 ChatMessage::system(ENRICHMENT_SYSTEM_PROMPT),
123 ChatMessage::user(user_message),
124 ],
125 max_tokens: 4096,
126 temperature: 0.3,
127 json_mode: true,
128 };
129
130 let result: EnrichmentBatchResult = self.client.generate_json(params).await?;
131
132 info!(
133 "Enrichment complete: {} memories returned",
134 result.memories.len()
135 );
136
137 Ok(result)
138 }
139
140 fn build_user_payload(
142 &self,
143 candidates: &[MemoryCandidate],
144 event: &NormalizedHookEvent,
145 ) -> serde_json::Value {
146 let event_obj = serde_json::json!({
148 "agent": event.agent,
149 "event_name": event.event_name,
150 "tool_name": event.tool_name,
151 "session_id": event.session_id,
152 "turn_id": event.turn_id,
153 });
154
155 let candidates_array: Vec<serde_json::Value> = candidates
157 .iter()
158 .map(|c| {
159 serde_json::json!({
160 "candidate_id": c.candidate_id,
161 "signal_score": c.signal_score,
162 "provisional_category": c.provisional_category,
163 "memory_text": c.memory_text,
164 "evidence": c.evidence,
165 "labels": c.labels,
166 })
167 })
168 .collect();
169
170 serde_json::json!({
171 "event": event_obj,
172 "candidates": candidates_array,
173 })
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use serde_json::json;

    /// Stub [`LlmClient`] whose `generate` always answers with an empty JSON
    /// object. It exists only so an `EnrichmentService` can be constructed
    /// without a real backend.
    struct MockLlmClient;

    impl MockLlmClient {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockLlmClient {
        async fn generate(
            &self,
            _params: GenerateParams,
        ) -> nexus_llm::Result<nexus_llm::GenerateResponse> {
            Ok(nexus_llm::GenerateResponse {
                content: "{}".to_string(),
                model: "mock-model".to_string(),
                usage: None,
            })
        }

        fn provider_name(&self) -> String {
            "mock".to_string()
        }

        fn model_name(&self) -> String {
            "mock-model".to_string()
        }
    }

    /// A full serialize/deserialize round trip keeps the field values, and
    /// the serialized form uses the canonical `memory_text` key.
    #[test]
    fn test_enriched_memory_serialization() {
        let memory = EnrichedMemory {
            store: true,
            category: "preferences".to_string(),
            memory_text: "User prefers Rust over C++ for systems programming".to_string(),
            labels: vec![
                "rust".to_string(),
                "cpp".to_string(),
                "preferences".to_string(),
            ],
            memory_lane_type: Some("preference".to_string()),
            comment: "Clear preference statement affecting future tool selection".to_string(),
            confidence: 0.9,
        };

        let serialized = serde_json::to_string(&memory).unwrap();
        let deserialized: EnrichedMemory = serde_json::from_str(&serialized).unwrap();

        assert!(deserialized.store);
        assert_eq!(deserialized.category, "preferences");
        assert_eq!(deserialized.labels.len(), 3);
        assert_eq!(deserialized.memory_text, memory.memory_text);
        assert_eq!(deserialized.memory_lane_type.as_deref(), Some("preference"));
        assert_eq!(deserialized.comment, memory.comment);

        // The alias only applies when reading; output must use `memory_text`.
        let value = serde_json::to_value(&memory).unwrap();
        assert_eq!(
            value["memory_text"],
            "User prefers Rust over C++ for systems programming"
        );
        assert!(value.get("memory").is_none());
    }

    /// The `memory` alias is accepted for `memory_text`, and an absent
    /// `memory_lane_type` deserializes to `None`.
    #[test]
    fn test_enriched_memory_accepts_memory_alias() {
        let payload = json!({
            "store": true,
            "category": "facts",
            "memory": "Project uses SQLite for persistence",
            "labels": ["sqlite", "database"],
            "comment": "Architecture fact",
            "confidence": 0.5
        });

        let memory: EnrichedMemory =
            serde_json::from_value(payload).expect("deserialize `memory` alias payload");

        assert_eq!(memory.memory_text, "Project uses SQLite for persistence");
        assert!(memory.memory_lane_type.is_none());
        assert_eq!(memory.labels, vec!["sqlite".to_string(), "database".to_string()]);
    }

    #[test]
    fn test_enrichment_batch_result_serialization() {
        let result = EnrichmentBatchResult {
            memories: vec![EnrichedMemory {
                store: true,
                category: "facts".to_string(),
                memory_text: "Project uses SQLite for persistence".to_string(),
                labels: vec!["sqlite".to_string(), "database".to_string()],
                memory_lane_type: None,
                comment: "Architecture fact".to_string(),
                confidence: 0.95,
            }],
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: EnrichmentBatchResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
        assert!(deserialized.memories[0].memory_lane_type.is_none());
    }

    #[test]
    fn test_enrichment_batch_result_accepts_accepted_memories_alias() {
        let payload = json!({
            "accepted_memories": [
                {
                    "store": true,
                    "category": "facts",
                    "memory_text": "Project uses SQLite for persistence",
                    "labels": ["sqlite", "database"],
                    "memory_lane_type": null,
                    "comment": "Architecture fact",
                    "confidence": 0.95
                }
            ]
        });

        let deserialized: EnrichmentBatchResult =
            serde_json::from_value(payload).expect("deserialize alias payload");

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
    }

    /// A response without any memories key (such as the `{}` the mock client
    /// returns) must parse to an empty batch rather than fail.
    #[test]
    fn test_enrichment_batch_result_defaults_to_empty() {
        let deserialized: EnrichmentBatchResult =
            serde_json::from_value(json!({})).expect("deserialize empty object");

        assert!(deserialized.memories.is_empty());
    }

    /// The user payload forwards exactly the documented event and candidate
    /// fields, with their values intact.
    #[test]
    fn test_build_user_payload() {
        let service = EnrichmentService {
            client: Arc::new(MockLlmClient::new()),
            model_name: "test-model".to_string(),
        };

        let candidates = vec![MemoryCandidate {
            candidate_id: "test-1".to_string(),
            source_event_name: "test_event".to_string(),
            source_agent: "test-agent".to_string(),
            signal_score: 0.8,
            provisional_category: Some("preferences".to_string()),
            memory_text: "Test memory".to_string(),
            evidence: json!({"key": "value"}),
            labels: vec!["test".to_string()],
        }];

        let event = NormalizedHookEvent {
            agent: "claude-code".to_string(),
            event_name: "test_event".to_string(),
            observed_at: Utc::now(),
            session_id: Some("session-123".to_string()),
            turn_id: Some("turn-456".to_string()),
            cwd: Some("/home/user/project".to_string()),
            tool_name: Some("test_tool".to_string()),
            tool_input: None,
            tool_response_text: None,
            assistant_message_text: None,
            user_message_text: None,
            raw_payload: json!({}),
        };

        let payload = service.build_user_payload(&candidates, &event);

        // Event section: five keys, nothing else (e.g. `cwd` is not forwarded).
        let event_obj = payload["event"].as_object().expect("event is an object");
        assert_eq!(event_obj.len(), 5);
        assert_eq!(payload["event"]["agent"], "claude-code");
        assert_eq!(payload["event"]["event_name"], "test_event");
        assert_eq!(payload["event"]["tool_name"], "test_tool");
        assert_eq!(payload["event"]["session_id"], "session-123");
        assert_eq!(payload["event"]["turn_id"], "turn-456");
        assert!(payload["event"].get("cwd").is_none());

        // Candidate section: one entry with six keys.
        assert_eq!(payload["candidates"].as_array().unwrap().len(), 1);
        let candidate = &payload["candidates"][0];
        assert_eq!(candidate.as_object().expect("candidate is an object").len(), 6);
        assert_eq!(candidate["candidate_id"], "test-1");
        assert_eq!(candidate["provisional_category"], "preferences");
        assert_eq!(candidate["memory_text"], "Test memory");
        assert_eq!(candidate["evidence"], json!({"key": "value"}));
        assert_eq!(candidate["labels"], json!(["test"]));
        // Compare with a tolerance: the score may be stored as `f32`.
        let score = candidate["signal_score"]
            .as_f64()
            .expect("signal_score is a number");
        assert!((score - 0.8).abs() < 1e-6);
    }
}