use std::sync::Arc;

use anyhow::Context;
use nexus_llm::{
    create_client_auto_with_fallback, ChatMessage, GenerateParams, LlmClient, LlmClientJson,
};
use serde::{Deserialize, Serialize};
use tracing::{debug, info};

use crate::candidate::MemoryCandidate;
use crate::claude_payload::NormalizedHookEvent;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct EnrichedMemory {
20 pub store: bool,
22 pub category: String,
24 pub memory_text: String,
26 pub labels: Vec<String>,
28 #[serde(rename = "memory_lane_type")]
30 pub memory_lane_type: Option<String>,
31 pub comment: String,
33 pub confidence: f32,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct EnrichmentBatchResult {
40 pub memories: Vec<EnrichedMemory>,
41}
42
/// System prompt for the enrichment request.
///
/// Besides the selection rules, the prompt spells out the exact JSON shape of the
/// reply. The reply is deserialized directly into `EnrichmentBatchResult` /
/// `EnrichedMemory`, so the field names listed here must stay in sync with
/// those structs.
const ENRICHMENT_SYSTEM_PROMPT: &str = r#"You are enriching agent hook events into durable memories for a retrieval system.

Decide whether each candidate is worth storing.

Only keep information that is durable, decision-relevant, preference-revealing, specification-bearing, contextual in a useful way, or session-significant.

Allowed categories:
- general
- facts
- preferences
- context
- specifications
- session

For each accepted memory:
- rewrite the memory into a standalone retrieval-friendly sentence or short paragraph
- assign exactly one allowed category
- produce 2-5 labels
- optionally assign a memory_lane_type from: correction, decision, commitment, insight, learning, confidence, pattern_seed, cross_agent, workflow_note, gap
- produce a comment explaining why the memory is worth keeping
- estimate a confidence between 0.0 and 1.0 that the memory is accurate and worth keeping

The comment must be model-authored and should explain why the memory is worth keeping, what retrieval value it has, or how it should be interpreted later.

Reject low-signal operational noise.

Return strict JSON only. No markdown fences. The response must be a single JSON object with exactly this shape:
{
  "memories": [
    {
      "store": true,
      "category": "facts",
      "memory_text": "Standalone rewrite of the memory.",
      "labels": ["label-one", "label-two"],
      "memory_lane_type": "decision",
      "comment": "Why this memory is worth keeping.",
      "confidence": 0.8
    }
  ]
}

Set "store" to true for memories worth keeping and false for rejected candidates. Use null for "memory_lane_type" when no lane type applies. If no candidate is worth storing, return {"memories": []}."#;
70
71pub struct EnrichmentService {
73 client: Arc<dyn LlmClient>,
74 model_name: String,
75}
76
77impl EnrichmentService {
78 pub fn new() -> anyhow::Result<Self> {
80 let client = create_client_auto_with_fallback()?;
81 let model_name = client.model_name();
82
83 info!("EnrichmentService initialized with model: {}", model_name);
84
85 Ok(Self { client, model_name })
86 }
87
88 pub fn model_name(&self) -> &str {
90 &self.model_name
91 }
92
93 pub async fn enrich_candidates(
98 &self,
99 candidates: &[MemoryCandidate],
100 event: &NormalizedHookEvent,
101 ) -> anyhow::Result<EnrichmentBatchResult> {
102 if candidates.is_empty() {
103 debug!("No candidates to enrich");
104 return Ok(EnrichmentBatchResult {
105 memories: Vec::new(),
106 });
107 }
108
109 let user_payload = self.build_user_payload(candidates, event);
110 let user_message = serde_json::to_string_pretty(&user_payload)?;
111
112 debug!(
113 "Enriching {} candidates with event: {}",
114 candidates.len(),
115 event.event_name
116 );
117
118 let params = GenerateParams {
119 messages: vec![
120 ChatMessage::system(ENRICHMENT_SYSTEM_PROMPT),
121 ChatMessage::user(user_message),
122 ],
123 max_tokens: 4096,
124 temperature: 0.3,
125 json_mode: true,
126 };
127
128 let result: EnrichmentBatchResult = self.client.generate_json(params).await?;
129
130 info!(
131 "Enrichment complete: {} memories returned",
132 result.memories.len()
133 );
134
135 Ok(result)
136 }
137
138 fn build_user_payload(
140 &self,
141 candidates: &[MemoryCandidate],
142 event: &NormalizedHookEvent,
143 ) -> serde_json::Value {
144 let event_obj = serde_json::json!({
146 "agent": event.agent,
147 "event_name": event.event_name,
148 "tool_name": event.tool_name,
149 "session_id": event.session_id,
150 "turn_id": event.turn_id,
151 });
152
153 let candidates_array: Vec<serde_json::Value> = candidates
155 .iter()
156 .map(|c| {
157 serde_json::json!({
158 "candidate_id": c.candidate_id,
159 "signal_score": c.signal_score,
160 "provisional_category": c.provisional_category,
161 "memory_text": c.memory_text,
162 "evidence": c.evidence,
163 "labels": c.labels,
164 })
165 })
166 .collect();
167
168 serde_json::json!({
169 "event": event_obj,
170 "candidates": candidates_array,
171 })
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use serde_json::json;

    #[test]
    fn test_enriched_memory_serialization() {
        let memory = EnrichedMemory {
            store: true,
            category: "preferences".to_string(),
            memory_text: "User prefers Rust over C++ for systems programming".to_string(),
            labels: vec![
                "rust".to_string(),
                "cpp".to_string(),
                "preferences".to_string(),
            ],
            // Must be one of the lane types the system prompt allows.
            memory_lane_type: Some("insight".to_string()),
            comment: "Clear preference statement affecting future tool selection".to_string(),
            confidence: 0.9,
        };

        let serialized = serde_json::to_string(&memory).unwrap();
        let deserialized: EnrichedMemory = serde_json::from_str(&serialized).unwrap();

        assert!(deserialized.store);
        assert_eq!(deserialized.category, "preferences");
        assert_eq!(deserialized.labels.len(), 3);
        assert_eq!(deserialized.memory_lane_type.as_deref(), Some("insight"));
        assert!((deserialized.confidence - 0.9).abs() < f32::EPSILON);
    }

    #[test]
    fn test_enriched_memory_missing_lane_type_deserializes_as_none() {
        // The prompt makes memory_lane_type optional, so the model may omit it.
        let raw = json!({
            "store": true,
            "category": "facts",
            "memory_text": "Project uses SQLite for persistence",
            "labels": ["sqlite", "database"],
            "comment": "Architecture fact",
            "confidence": 0.95
        });

        let memory: EnrichedMemory = serde_json::from_value(raw).unwrap();

        assert!(memory.memory_lane_type.is_none());
        assert_eq!(memory.category, "facts");
    }

    #[test]
    fn test_enrichment_batch_result_serialization() {
        let result = EnrichmentBatchResult {
            memories: vec![EnrichedMemory {
                store: true,
                category: "facts".to_string(),
                memory_text: "Project uses SQLite for persistence".to_string(),
                labels: vec!["sqlite".to_string(), "database".to_string()],
                memory_lane_type: None,
                comment: "Architecture fact".to_string(),
                confidence: 0.95,
            }],
        };

        let serialized = serde_json::to_string(&result).unwrap();
        let deserialized: EnrichmentBatchResult = serde_json::from_str(&serialized).unwrap();

        assert_eq!(deserialized.memories.len(), 1);
        assert_eq!(deserialized.memories[0].category, "facts");
        assert!(deserialized.memories[0].memory_lane_type.is_none());
    }

    #[test]
    fn test_build_user_payload() {
        let service = EnrichmentService {
            client: Arc::new(MockLlmClient::new()),
            model_name: "test-model".to_string(),
        };

        let candidates = vec![MemoryCandidate {
            candidate_id: "test-1".to_string(),
            source_event_name: "test_event".to_string(),
            source_agent: "test-agent".to_string(),
            signal_score: 0.8,
            provisional_category: Some("preferences".to_string()),
            memory_text: "Test memory".to_string(),
            evidence: json!({"key": "value"}),
            labels: vec!["test".to_string()],
        }];

        let event = NormalizedHookEvent {
            agent: "claude-code".to_string(),
            event_name: "test_event".to_string(),
            observed_at: Utc::now(),
            session_id: Some("session-123".to_string()),
            turn_id: Some("turn-456".to_string()),
            cwd: Some("/home/user/project".to_string()),
            tool_name: Some("test_tool".to_string()),
            tool_input: None,
            tool_response_text: None,
            assistant_message_text: None,
            user_message_text: None,
            raw_payload: json!({}),
        };

        let payload = service.build_user_payload(&candidates, &event);

        assert_eq!(payload["event"]["agent"], "claude-code");
        assert_eq!(payload["event"]["event_name"], "test_event");
        assert_eq!(payload["event"]["tool_name"], "test_tool");
        assert_eq!(payload["event"]["session_id"], "session-123");
        assert_eq!(payload["event"]["turn_id"], "turn-456");

        let forwarded = payload["candidates"].as_array().unwrap();
        assert_eq!(forwarded.len(), 1);
        assert_eq!(forwarded[0]["candidate_id"], "test-1");
        assert_eq!(forwarded[0]["provisional_category"], "preferences");
        assert_eq!(forwarded[0]["memory_text"], "Test memory");
        assert_eq!(forwarded[0]["evidence"]["key"], "value");
        assert_eq!(forwarded[0]["labels"], json!(["test"]));
    }

    struct MockLlmClient;

    impl MockLlmClient {
        fn new() -> Self {
            Self
        }
    }

    #[async_trait::async_trait]
    impl LlmClient for MockLlmClient {
        async fn generate(
            &self,
            _params: GenerateParams,
        ) -> nexus_llm::Result<nexus_llm::GenerateResponse> {
            Ok(nexus_llm::GenerateResponse {
                content: "{}".to_string(),
                model: "mock-model".to_string(),
                usage: None,
            })
        }

        fn provider_name(&self) -> String {
            "mock".to_string()
        }

        fn model_name(&self) -> String {
            "mock-model".to_string()
        }
    }
}