harness_context/
memory_guard.rs1use async_trait::async_trait;
32use harness_core::{Memory, MemoryEntry, MemoryError};
33use regex::Regex;
34use std::collections::HashSet;
35use std::sync::Arc;
36
37pub struct GuardedMemory {
40 inner: Arc<dyn Memory>,
41 sensitivity_patterns: Vec<Regex>,
42 blocked_substrings: Vec<String>,
43 dedup_threshold: f32,
44 dedup_recall_k: usize,
45}
46
47impl GuardedMemory {
48 pub fn new(inner: Arc<dyn Memory>) -> Self {
52 Self {
53 inner,
54 sensitivity_patterns: default_sensitivity_patterns(),
55 blocked_substrings: Vec::new(),
56 dedup_threshold: 0.6,
57 dedup_recall_k: 5,
58 }
59 }
60
61 pub fn without_default_sensitivity(mut self) -> Self {
64 self.sensitivity_patterns.clear();
65 self
66 }
67
68 pub fn with_sensitivity_pattern(mut self, pat: impl AsRef<str>) -> Result<Self, regex::Error> {
71 self.sensitivity_patterns.push(Regex::new(pat.as_ref())?);
72 Ok(self)
73 }
74
75 pub fn with_blocked_substring(mut self, s: impl Into<String>) -> Self {
79 self.blocked_substrings.push(s.into().to_lowercase());
80 self
81 }
82
83 pub fn with_dedup_threshold(mut self, t: f32) -> Self {
87 self.dedup_threshold = t.clamp(0.0, 1.0);
88 self
89 }
90
91 pub fn with_dedup_recall_k(mut self, k: usize) -> Self {
95 self.dedup_recall_k = k.max(1);
96 self
97 }
98
99 fn is_sensitive(&self, content: &str) -> bool {
100 let lower = content.to_lowercase();
101 if self.blocked_substrings.iter().any(|s| lower.contains(s)) {
102 return true;
103 }
104 self.sensitivity_patterns
105 .iter()
106 .any(|r| r.is_match(content))
107 }
108
109 async fn is_duplicate(&self, entry: &MemoryEntry) -> bool {
110 if self.dedup_threshold <= 0.0 {
111 return false;
112 }
113 let cands = match self.inner.recall(&entry.content, self.dedup_recall_k).await {
114 Ok(v) => v,
115 Err(_) => return false,
116 };
117 let new_tokens = jaccard_tokens(&entry.content);
118 if new_tokens.is_empty() {
119 return false;
120 }
121 for c in cands {
122 let cand_tokens = jaccard_tokens(&c.content);
123 if jaccard(&new_tokens, &cand_tokens) >= self.dedup_threshold {
124 return true;
125 }
126 }
127 false
128 }
129}
130
131#[async_trait]
132impl Memory for GuardedMemory {
133 async fn recall(&self, query: &str, k: usize) -> Result<Vec<MemoryEntry>, MemoryError> {
134 self.inner.recall(query, k).await
135 }
136
137 async fn write(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
138 if self.is_sensitive(&entry.content) {
139 tracing::info!(
140 content_preview = %entry.content.chars().take(40).collect::<String>(),
141 "guarded memory: dropping sensitive entry"
142 );
143 return Ok(());
144 }
145 if self.is_duplicate(&entry).await {
146 tracing::info!(
147 content_preview = %entry.content.chars().take(40).collect::<String>(),
148 "guarded memory: dropping duplicate entry"
149 );
150 return Ok(());
151 }
152 self.inner.write(entry).await
153 }
154}
155
156fn default_sensitivity_patterns() -> Vec<Regex> {
157 [
158 r"\b\d{13,19}\b",
160 r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
162 r"\b1[3-9]\d{9}\b",
164 r"[¥$€£₹]\s?\d+(?:[.,]\d+)?",
168 r"\b(?:USD|CNY|EUR|RMB|HKD|JPY)\s?\d+(?:[.,]\d+)?\b",
169 ]
170 .iter()
171 .filter_map(|p| Regex::new(p).ok())
172 .collect()
173}
174
/// Tokenizes `s` for the Jaccard duplicate check.
///
/// The text is lower-cased and split on every non-alphanumeric character. Pieces shorter
/// than 3 *bytes* are discarded. Because the length is measured in UTF-8 bytes, a single
/// CJK character (3 bytes) already qualifies. Unbroken CJK text contains no separators and
/// therefore forms a single token.
fn jaccard_tokens(s: &str) -> HashSet<String> {
    let lowered = s.to_lowercase();
    let mut tokens = HashSet::new();
    for piece in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.len() >= 3 {
            tokens.insert(piece.to_owned());
        }
    }
    tokens
}
182
/// Jaccard similarity `|a ∩ b| / |a ∪ b|` of two token sets, in `[0.0, 1.0]`.
///
/// Returns `0.0` when either set is empty.
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f32 {
    if a.is_empty() || b.is_empty() {
        return 0.0;
    }
    let shared = a.intersection(b).count();
    // |A ∪ B| = |A| + |B| - |A ∩ B|; both sets are non-empty here, so this is at least 1.
    let total = a.len() + b.len() - shared;
    shared as f32 / total as f32
}
191
#[cfg(test)]
mod tests {
    use super::*;
    use harness_core::Memory;
    use std::sync::Mutex;

    /// Minimal in-process `Memory` used as the guarded store in these tests.
    ///
    /// `recall` ranks entries by how many of the query's `jaccard_tokens` occur as
    /// substrings of the lower-cased content. `snapshot` exposes the raw store, so tests
    /// can assert on everything that was written, not just what one query happens to hit.
    #[derive(Default)]
    struct VecMemory {
        store: Mutex<Vec<MemoryEntry>>,
    }

    impl VecMemory {
        /// Every entry written so far, in write order.
        fn snapshot(&self) -> Vec<MemoryEntry> {
            self.store.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl Memory for VecMemory {
        async fn recall(&self, query: &str, k: usize) -> Result<Vec<MemoryEntry>, MemoryError> {
            let g = self.store.lock().unwrap();
            let q_tokens = jaccard_tokens(query);
            // With no usable query tokens, fall back to the first `k` entries.
            if q_tokens.is_empty() {
                return Ok(g.iter().take(k).cloned().collect());
            }
            let mut scored: Vec<(u32, &MemoryEntry)> = g
                .iter()
                .map(|e| {
                    let hay = e.content.to_lowercase();
                    let hits: u32 = q_tokens
                        .iter()
                        .map(|t| if hay.contains(t.as_str()) { 1 } else { 0 })
                        .sum();
                    (hits, e)
                })
                .filter(|(hits, _)| *hits > 0)
                .collect();
            scored.sort_by(|a, b| b.0.cmp(&a.0));
            Ok(scored.into_iter().take(k).map(|(_, e)| e.clone()).collect())
        }

        async fn write(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
            self.store.lock().unwrap().push(entry);
            Ok(())
        }
    }

    #[tokio::test]
    async fn sensitive_credit_card_is_dropped() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new(
            "user's card is 4111111111111111 expiry 12/30",
        ))
        .await
        .unwrap();
        let all = inner.snapshot();
        assert!(all.is_empty(), "credit-card-like content should be dropped: {all:?}");
    }

    #[tokio::test]
    async fn sensitive_email_is_dropped() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new("user's email is ll_faw@hotmail.com"))
            .await
            .unwrap();
        let all = inner.snapshot();
        assert!(all.is_empty(), "email content should be dropped: {all:?}");
    }

    #[tokio::test]
    async fn monetary_amounts_are_dropped() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new("用户记录了一笔 ¥199 火锅消费"))
            .await
            .unwrap();
        mem.write(MemoryEntry::new("user spent USD 250 on Claude Code"))
            .await
            .unwrap();
        // Inspect the whole store: a `recall("user")` could never surface the Chinese
        // entry, so it would not detect a failure of the `¥` pattern.
        let all = inner.snapshot();
        assert!(
            all.is_empty(),
            "monetary patterns should be filtered: {all:?}"
        );
    }

    #[tokio::test]
    async fn durable_preferences_pass_through() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new("用户偏好使用微信支付餐饮类支出"))
            .await
            .unwrap();
        mem.write(MemoryEntry::new(
            "user prefers concise replies in Slack style",
        ))
        .await
        .unwrap();
        let all = inner.snapshot();
        assert_eq!(all.len(), 2, "both preferences should be kept: {all:?}");
        let zh = inner.recall("用户", 10).await.unwrap();
        assert_eq!(zh.len(), 1, "preference about 用户 should be kept");
    }

    #[tokio::test]
    async fn duplicate_is_dropped() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone()).with_dedup_threshold(0.6);
        mem.write(MemoryEntry::new(
            "user prefers concise replies written in Slack style",
        ))
        .await
        .unwrap();
        // Token Jaccard vs. the first entry is 5/8 = 0.625 >= 0.6.
        mem.write(MemoryEntry::new(
            "user prefers concise replies in Slack tone",
        ))
        .await
        .unwrap();
        let all = inner.snapshot();
        assert_eq!(
            all.len(),
            1,
            "near-duplicate should not double-store: {all:?}"
        );
    }

    #[tokio::test]
    async fn zero_threshold_disables_dedup() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone()).with_dedup_threshold(0.0);
        for _ in 0..2 {
            mem.write(MemoryEntry::new("user prefers dark mode"))
                .await
                .unwrap();
        }
        let all = inner.snapshot();
        assert_eq!(all.len(), 2, "threshold 0 should disable dedup: {all:?}");
    }

    #[tokio::test]
    async fn blocked_substring_works() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone()).with_blocked_substring("password");
        mem.write(MemoryEntry::new("user's password reset is hunter2"))
            .await
            .unwrap();
        let all = inner.snapshot();
        assert!(all.is_empty(), "blocked substring should drop the entry: {all:?}");
    }

    #[tokio::test]
    async fn blocked_substring_is_case_insensitive() {
        let inner = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone()).with_blocked_substring("PassWord");
        mem.write(MemoryEntry::new("Reset your PASSWORD soon"))
            .await
            .unwrap();
        let all = inner.snapshot();
        assert!(all.is_empty(), "blocking should ignore case: {all:?}");
    }
}