1use std::collections::HashSet;
2
3use crate::memory::semantic::MemoryEntry;
4use crate::memory::trace_analyzer::{InsightKind, TraceInsight};
5
/// Strategy used when a new candidate entry is judged similar to an entry
/// that is already stored.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConflictResolution {
    /// The incoming candidate always replaces the similar stored entry.
    PreferNewer,
    /// The incoming candidate replaces the stored entry only when its score is
    /// greater than or equal to the stored entry's score; otherwise the
    /// candidate is discarded.
    PreferHigherConfidence,
}
14
15#[derive(Debug, Clone)]
16pub struct CurationPolicy {
17 pub similarity_threshold: f64,
19 pub max_entries: usize,
21 pub conflict_resolution: ConflictResolution,
22 pub min_confidence: f64,
24}
25
26impl Default for CurationPolicy {
27 fn default() -> Self {
28 Self {
29 similarity_threshold: 0.65,
30 max_entries: 500,
31 conflict_resolution: ConflictResolution::PreferNewer,
32 min_confidence: 0.3,
33 }
34 }
35}
36
/// Counters describing what happened during a single curation pass.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CurationStats {
    /// Number of insights handed to the pass, including ones later skipped.
    pub insights_processed: usize,
    /// Candidates discarded because a similar entry (stored or earlier in the
    /// same batch) was kept instead.
    pub duplicates_removed: usize,
    /// Candidates that displaced a similar stored entry.
    pub conflicts_resolved: usize,
    /// Number of entries in the final list of additions.
    pub entries_added: usize,
}
44
45#[derive(Debug, Clone)]
47pub struct CurationResult {
48 pub to_add: Vec<MemoryEntry>,
49 pub to_remove_indices: Vec<usize>,
51 pub stats: CurationStats,
52}
53
54pub struct MemoryCurator {
55 pub policy: CurationPolicy,
56}
57
58impl MemoryCurator {
59 pub fn new(policy: CurationPolicy) -> Self {
60 Self { policy }
61 }
62
63 pub fn curate(
67 &self,
68 insights: &[TraceInsight],
69 existing: &[MemoryEntry],
70 now_ms: u64,
71 ) -> CurationResult {
72 let mut stats = CurationStats {
73 insights_processed: insights.len(),
74 ..Default::default()
75 };
76 let mut to_add: Vec<MemoryEntry> = Vec::new();
77 let mut to_remove_indices: Vec<usize> = Vec::new();
78
79 for insight in insights {
80 if insight.confidence < self.policy.min_confidence {
81 continue;
82 }
83
84 let candidate = insight_to_entry(insight, now_ms);
85
86 let mut conflict_idx: Option<usize> = None;
88 for (idx, existing_entry) in existing.iter().enumerate() {
89 if to_remove_indices.contains(&idx) {
90 continue; }
92 if jaccard(&candidate.text, &existing_entry.text)
93 >= self.policy.similarity_threshold
94 {
95 conflict_idx = Some(idx);
96 break;
97 }
98 }
99
100 if let Some(idx) = conflict_idx {
101 let existing_entry = &existing[idx];
102 let keep_new = match self.policy.conflict_resolution {
103 ConflictResolution::PreferNewer => true,
104 ConflictResolution::PreferHigherConfidence => {
105 candidate.score >= existing_entry.score
106 }
107 };
108 if keep_new {
109 to_remove_indices.push(idx);
110 stats.conflicts_resolved += 1;
111 } else {
112 stats.duplicates_removed += 1;
113 continue;
114 }
115 }
116
117 let dup_in_batch = to_add
119 .iter()
120 .any(|e| jaccard(&candidate.text, &e.text) >= self.policy.similarity_threshold);
121 if dup_in_batch {
122 stats.duplicates_removed += 1;
123 continue;
124 }
125
126 to_add.push(candidate);
127 stats.entries_added += 1;
128 }
129
130 to_remove_indices.sort_unstable();
131 to_remove_indices.dedup();
132
133 let surviving_existing = existing.len().saturating_sub(to_remove_indices.len());
135 let headroom = self.policy.max_entries.saturating_sub(surviving_existing);
136 to_add.truncate(headroom);
137 stats.entries_added = to_add.len();
138
139 CurationResult {
140 to_add,
141 to_remove_indices,
142 stats,
143 }
144 }
145}
146
147fn insight_to_entry(insight: &TraceInsight, now_ms: u64) -> MemoryEntry {
150 let text = match &insight.kind {
151 InsightKind::RepeatedToolError {
152 tool_name,
153 error_count,
154 sample_error,
155 } => {
156 format!(
157 "Tool '{}' failed {} times; pattern: {}",
158 tool_name, error_count, sample_error
159 )
160 }
161 InsightKind::SuccessfulToolSequence {
162 tools,
163 context_hint,
164 } => {
165 format!(
166 "Successful sequence [{}] for: {}",
167 tools.join(" → "),
168 context_hint
169 )
170 }
171 InsightKind::LongReasoning { summary_hint } => summary_hint.clone(),
172 InsightKind::Synthesized { text } => text.clone(),
173 };
174 let metadata = serde_json::json!({
175 "kind": insight.kind.tag(),
176 "confidence": insight.confidence,
177 "session_id": insight.session_id,
178 "extracted_at_ms": now_ms,
179 });
180 MemoryEntry {
181 text,
182 score: insight.confidence,
183 metadata,
184 }
185}
186
/// Jaccard similarity of the whitespace-separated token sets of `a` and `b`.
///
/// Tokens are compared exactly (case and punctuation matter) and repeated
/// tokens count once. Returns `0.0` when neither string contains a token.
fn jaccard(a: &str, b: &str) -> f64 {
    let left: HashSet<&str> = a.split_whitespace().collect();
    let right: HashSet<&str> = b.split_whitespace().collect();

    let shared = left.iter().filter(|token| right.contains(*token)).count();
    // |A ∪ B| = |A| + |B| − |A ∩ B|
    let total = left.len() + right.len() - shared;

    match total {
        0 => 0.0,
        _ => shared as f64 / total as f64,
    }
}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::trace_analyzer::{AnalysisPolicy, InsightKind, TraceAnalyzer, TraceInsight};
    use pretty_assertions::assert_eq;

    /// Curator running with the default policy.
    fn curator() -> MemoryCurator {
        let policy = CurationPolicy::default();
        MemoryCurator::new(policy)
    }

    /// A "repeated tool error" insight for `tool` with the given confidence.
    fn error_insight(tool: &str, confidence: f64) -> TraceInsight {
        let kind = InsightKind::RepeatedToolError {
            tool_name: tool.to_string(),
            error_count: 3,
            sample_error: "permission denied".to_string(),
        };
        TraceInsight {
            kind,
            confidence,
            session_id: "s1".to_string(),
        }
    }

    /// A stored entry with the given text and score and no metadata.
    fn existing_entry(text: &str, score: f64) -> MemoryEntry {
        MemoryEntry {
            text: text.to_string(),
            score,
            metadata: serde_json::Value::Null,
        }
    }

    #[test]
    fn adds_new_insights_when_no_existing() {
        let insights = [error_insight("bash", 0.8)];
        let outcome = curator().curate(&insights, &[], 0);

        assert_eq!(outcome.to_add.len(), 1);
        assert!(outcome.to_remove_indices.is_empty());
        assert_eq!(outcome.stats.entries_added, 1);
    }

    #[test]
    fn skips_low_confidence_insights() {
        // 0.1 is below the default confidence floor of 0.3.
        let insights = [error_insight("bash", 0.1)];
        let outcome = curator().curate(&insights, &[], 0);

        assert!(outcome.to_add.is_empty());
        assert_eq!(outcome.stats.entries_added, 0);
    }

    #[test]
    fn prefer_newer_replaces_similar_existing() {
        let stored = vec![existing_entry(
            "Tool 'bash' failed 2 times; pattern: permission denied",
            0.4,
        )];
        let insights = [error_insight("bash", 0.8)];

        let outcome = curator().curate(&insights, &stored, 1000);

        assert_eq!(outcome.to_add.len(), 1);
        assert_eq!(outcome.to_remove_indices, vec![0]);
        assert_eq!(outcome.stats.conflicts_resolved, 1);
    }

    #[test]
    fn prefer_higher_confidence_keeps_existing_when_better() {
        let curator = MemoryCurator::new(CurationPolicy {
            conflict_resolution: ConflictResolution::PreferHigherConfidence,
            ..Default::default()
        });
        let stored = vec![existing_entry(
            "Tool 'bash' failed 3 times; pattern: permission denied",
            0.95,
        )];
        let insights = [error_insight("bash", 0.5)];

        let outcome = curator.curate(&insights, &stored, 0);

        assert!(outcome.to_add.is_empty());
        assert!(outcome.to_remove_indices.is_empty());
        assert_eq!(outcome.stats.duplicates_removed, 1);
    }

    #[test]
    fn deduplicates_within_batch() {
        // Both insights render to the same text, so only the first survives.
        let batch = vec![error_insight("bash", 0.8), error_insight("bash", 0.7)];

        let outcome = curator().curate(&batch, &[], 0);

        assert_eq!(outcome.to_add.len(), 1);
        assert_eq!(outcome.stats.duplicates_removed, 1);
    }

    #[test]
    fn respects_max_entries_headroom() {
        let curator = MemoryCurator::new(CurationPolicy {
            max_entries: 2,
            ..Default::default()
        });
        // The store is already at capacity with entries unrelated to the insight.
        let stored = vec![
            existing_entry("unrelated entry one", 0.5),
            existing_entry("unrelated entry two", 0.5),
        ];
        let batch = vec![error_insight("bash", 0.8)];

        let outcome = curator.curate(&batch, &stored, 0);

        assert!(outcome.to_add.is_empty());
    }

    #[test]
    fn end_to_end_with_trace_analyzer() {
        use crate::types::message::{ContentPart, ToolCall};
        use compact_str::CompactString;

        // One `bash` tool call with the given id.
        let bash_call = |id: &str| ToolCall {
            id: CompactString::new(id),
            name: CompactString::new("bash"),
            arguments: serde_json::Value::Null,
        };
        // A failing tool result answering the call with the given id.
        let denied = |call_id: &str| {
            crate::types::message::Message::tool(vec![ContentPart::ToolResult {
                call_id: CompactString::new(call_id),
                output: "permission denied".to_string(),
                is_error: true,
            }])
        };

        let mut call_msg = crate::types::message::Message::assistant("");
        call_msg.tool_calls = vec![bash_call("c1"), bash_call("c2")];

        let messages = vec![call_msg, denied("c1"), denied("c2")];
        let analyzer = TraceAnalyzer::new(AnalysisPolicy::default());
        let insights = analyzer.analyze("s1", &messages);
        assert!(!insights.is_empty());

        let outcome = curator().curate(&insights, &[], 42_000);
        assert!(!outcome.to_add.is_empty());

        let first = &outcome.to_add[0];
        assert!(
            first.metadata["kind"] == "repeated_tool_error"
                || first.metadata["kind"] == "synthesized"
        );
        assert_eq!(first.metadata["extracted_at_ms"], 42_000);
    }
}