Skip to main content

harness_context/
memory_guard.rs

1//! Memory write-time guards: dedup + sensitivity filter.
2//!
3//! `GuardedMemory` wraps any `Arc<dyn Memory>` and runs two cheap checks
4//! before letting `write` through to the inner store:
5//!
6//! 1. **Sensitivity gate** — regex match against a configurable set of
7//!    patterns (credit card numbers, emails, phone, monetary amounts). If
8//!    the entry content trips any pattern, the write is silently dropped
9//!    (logged at info level). Defaults aim for "do no harm in personal /
10//!    financial contexts"; turn them off with `.without_default_sensitivity()`.
11//!
//! 2. **Dedup** — calls `inner.recall(&entry.content, k)` (`k` defaults to 5)
//!    and compares each candidate's content tokens against the new entry's
//!    tokens. If the Jaccard similarity reaches or exceeds `dedup_threshold`
//!    (default 0.6) for ANY candidate, the write is dropped — the existing
//!    entry already covers this fact, no need to inflate the file.
17//!
18//! `recall` and the underlying file ops are pass-through.
19//!
20//! Layered design — apply on top of `FileMemory` (or any other backend):
21//!
22//! ```ignore
23//! let file_mem = FileMemory::open(path)?;
24//! let memory: Arc<dyn Memory> = Arc::new(
25//!     GuardedMemory::new(Arc::new(file_mem))
26//!         .with_blocked_substring("password")
27//!         .with_dedup_threshold(0.55)
28//! );
29//! ```
30
31use async_trait::async_trait;
32use harness_core::{Memory, MemoryEntry, MemoryError};
33use regex::Regex;
34use std::collections::HashSet;
35use std::sync::Arc;
36
/// Wraps any `Arc<dyn Memory>` and adds dedup + sensitivity filtering on
/// `write`. `recall` is pass-through.
pub struct GuardedMemory {
    /// Backing store; receives every write that passes both guards and serves
    /// all `recall` calls.
    inner: Arc<dyn Memory>,
    /// Compiled regexes. Content matching ANY of them is treated as sensitive.
    sensitivity_patterns: Vec<Regex>,
    /// Literal block-list, stored lowercased so it can be compared against the
    /// lowercased entry content.
    blocked_substrings: Vec<String>,
    /// Jaccard token similarity at or above which a recalled candidate counts
    /// as a duplicate of the entry being written.
    dedup_threshold: f32,
    /// Number of candidates requested from `inner.recall` for the dedup check.
    dedup_recall_k: usize,
}
46
47impl GuardedMemory {
48    /// Wrap `inner` with the default sensitivity patterns (credit-card-like
49    /// 13-19 digit runs, emails, common monetary patterns) and a dedup
50    /// threshold of 0.6 Jaccard token overlap.
51    pub fn new(inner: Arc<dyn Memory>) -> Self {
52        Self {
53            inner,
54            sensitivity_patterns: default_sensitivity_patterns(),
55            blocked_substrings: Vec::new(),
56            dedup_threshold: 0.6,
57            dedup_recall_k: 5,
58        }
59    }
60
61    /// Drop the default sensitivity patterns — useful for tests or when
62    /// callers know they're storing already-redacted content.
63    pub fn without_default_sensitivity(mut self) -> Self {
64        self.sensitivity_patterns.clear();
65        self
66    }
67
68    /// Add an extra regex pattern. If the entry content matches ANY
69    /// pattern, the write is dropped.
70    pub fn with_sensitivity_pattern(mut self, pat: impl AsRef<str>) -> Result<Self, regex::Error> {
71        self.sensitivity_patterns.push(Regex::new(pat.as_ref())?);
72        Ok(self)
73    }
74
75    /// Add a literal substring to the block-list (case-insensitive contains).
76    /// Cheaper than a regex; use for app-specific terms that should never
77    /// hit memory (e.g. `"password"`, `"内部秘钥"`).
78    pub fn with_blocked_substring(mut self, s: impl Into<String>) -> Self {
79        self.blocked_substrings.push(s.into().to_lowercase());
80        self
81    }
82
83    /// Override the Jaccard token-overlap threshold above which an entry is
84    /// considered a duplicate of an existing one. Range [0.0, 1.0]; default
85    /// 0.6. Set to 1.0 to require exact match, 0.0 to disable dedup.
86    pub fn with_dedup_threshold(mut self, t: f32) -> Self {
87        self.dedup_threshold = t.clamp(0.0, 1.0);
88        self
89    }
90
91    /// How many candidates to fetch from `recall` for dedup comparison.
92    /// Default 5. Increase if your store gets large and recall miss rate
93    /// is high; decrease for tiny stores.
94    pub fn with_dedup_recall_k(mut self, k: usize) -> Self {
95        self.dedup_recall_k = k.max(1);
96        self
97    }
98
99    fn is_sensitive(&self, content: &str) -> bool {
100        let lower = content.to_lowercase();
101        if self.blocked_substrings.iter().any(|s| lower.contains(s)) {
102            return true;
103        }
104        self.sensitivity_patterns
105            .iter()
106            .any(|r| r.is_match(content))
107    }
108
109    async fn is_duplicate(&self, entry: &MemoryEntry) -> bool {
110        if self.dedup_threshold <= 0.0 {
111            return false;
112        }
113        let cands = match self.inner.recall(&entry.content, self.dedup_recall_k).await {
114            Ok(v) => v,
115            Err(_) => return false,
116        };
117        let new_tokens = jaccard_tokens(&entry.content);
118        if new_tokens.is_empty() {
119            return false;
120        }
121        for c in cands {
122            let cand_tokens = jaccard_tokens(&c.content);
123            if jaccard(&new_tokens, &cand_tokens) >= self.dedup_threshold {
124                return true;
125            }
126        }
127        false
128    }
129}
130
131#[async_trait]
132impl Memory for GuardedMemory {
133    async fn recall(&self, query: &str, k: usize) -> Result<Vec<MemoryEntry>, MemoryError> {
134        self.inner.recall(query, k).await
135    }
136
137    async fn write(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
138        if self.is_sensitive(&entry.content) {
139            tracing::info!(
140                content_preview = %entry.content.chars().take(40).collect::<String>(),
141                "guarded memory: dropping sensitive entry"
142            );
143            return Ok(());
144        }
145        if self.is_duplicate(&entry).await {
146            tracing::info!(
147                content_preview = %entry.content.chars().take(40).collect::<String>(),
148                "guarded memory: dropping duplicate entry"
149            );
150            return Ok(());
151        }
152        self.inner.write(entry).await
153    }
154}
155
156fn default_sensitivity_patterns() -> Vec<Regex> {
157    [
158        // 13-19 consecutive digits (credit card-ish — covers most PANs)
159        r"\b\d{13,19}\b",
160        // Emails
161        r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
162        // Chinese mainland mobile (1 + 10 digits)
163        r"\b1[3-9]\d{9}\b",
164        // Common monetary amount mentions — ¥1234.56, $1234, USD 1234,
165        // CNY 1234.56 (account flows that should live in the txns table,
166        // not in long-term memory)
167        r"[¥$€£₹]\s?\d+(?:[.,]\d+)?",
168        r"\b(?:USD|CNY|EUR|RMB|HKD|JPY)\s?\d+(?:[.,]\d+)?\b",
169    ]
170    .iter()
171    .filter_map(|p| Regex::new(p).ok())
172    .collect()
173}
174
/// Split `s` into the set of lowercase tokens used for Jaccard comparison.
///
/// The input is lowercased and cut at every non-alphanumeric character.
/// Pieces shorter than 3 BYTES are discarded, so one- and two-letter ASCII
/// words are dropped while a single CJK character (3 bytes in UTF-8) is kept.
fn jaccard_tokens(s: &str) -> HashSet<String> {
    let lowered = s.to_lowercase();
    let mut tokens = HashSet::new();
    for piece in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.len() < 3 {
            continue;
        }
        tokens.insert(piece.to_owned());
    }
    tokens
}
182
/// Jaccard similarity `|a ∩ b| / |a ∪ b|` of two token sets.
///
/// Returns `0.0` when either set is empty; otherwise a value in `(0.0, 1.0]`
/// for overlapping sets and `0.0` for disjoint ones.
fn jaccard(a: &HashSet<String>, b: &HashSet<String>) -> f32 {
    match (a.is_empty(), b.is_empty()) {
        (false, false) => {
            let shared = a.iter().filter(|tok| b.contains(*tok)).count();
            // Inclusion–exclusion: |a ∪ b| = |a| + |b| - |a ∩ b|.
            let total = a.len() + b.len() - shared;
            shared as f32 / total as f32
        }
        _ => 0.0,
    }
}
191
#[cfg(test)]
mod tests {
    use super::*;
    use harness_core::Memory;
    use std::sync::Mutex;

    /// Minimal in-memory `Memory` backend used to observe what reaches the
    /// inner store.
    #[derive(Default)]
    struct VecMemory {
        store: Mutex<Vec<MemoryEntry>>,
    }
    #[async_trait]
    impl Memory for VecMemory {
        async fn recall(&self, query: &str, k: usize) -> Result<Vec<MemoryEntry>, MemoryError> {
            // Substring-contains scoring against the lowercased content.
            // Plain Jaccard exact-token matching wouldn't substring-match CJK
            // content (where the whole string is one big token).
            //
            // A query with no usable tokens (e.g. "") returns the first `k`
            // stored entries unfiltered; tests use that to inspect the store.
            let g = self.store.lock().unwrap();
            let q_tokens = jaccard_tokens(query);
            if q_tokens.is_empty() {
                return Ok(g.iter().take(k).cloned().collect());
            }
            let mut scored: Vec<(u32, &MemoryEntry)> = g
                .iter()
                .map(|e| {
                    let hay = e.content.to_lowercase();
                    let hits: u32 = q_tokens
                        .iter()
                        .map(|t| if hay.contains(t.as_str()) { 1 } else { 0 })
                        .sum();
                    (hits, e)
                })
                .filter(|(hits, _)| *hits > 0)
                .collect();
            scored.sort_by(|a, b| b.0.cmp(&a.0));
            Ok(scored.into_iter().take(k).map(|(_, e)| e.clone()).collect())
        }
        async fn write(&self, entry: MemoryEntry) -> Result<(), MemoryError> {
            self.store.lock().unwrap().push(entry);
            Ok(())
        }
    }

    #[tokio::test]
    async fn sensitive_credit_card_is_dropped() {
        let inner: Arc<dyn Memory> = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new(
            "user's card is 4111111111111111 expiry 12/30",
        ))
        .await
        .unwrap();
        let all = inner.recall("card", 10).await.unwrap();
        assert!(all.is_empty(), "credit-card-like content should be dropped");
    }

    #[tokio::test]
    async fn sensitive_email_is_dropped() {
        let inner: Arc<dyn Memory> = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new("user's email is ll_faw@hotmail.com"))
            .await
            .unwrap();
        let all = inner.recall("email", 10).await.unwrap();
        assert!(all.is_empty());
    }

    #[tokio::test]
    async fn monetary_amounts_are_dropped() {
        let inner: Arc<dyn Memory> = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new("用户记录了一笔 ¥199 火锅消费"))
            .await
            .unwrap();
        mem.write(MemoryEntry::new("user spent USD 250 on Claude Code"))
            .await
            .unwrap();
        // Inspect the whole store. Querying "user" can never match the Chinese
        // entry, so it would not notice if the ¥ pattern stopped working.
        let all = inner.recall("", 10).await.unwrap();
        assert!(
            all.is_empty(),
            "monetary patterns should be filtered: {all:?}"
        );
    }

    #[tokio::test]
    async fn durable_preferences_pass_through() {
        let inner: Arc<dyn Memory> = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone());
        mem.write(MemoryEntry::new("用户偏好使用微信支付餐饮类支出"))
            .await
            .unwrap();
        mem.write(MemoryEntry::new(
            "user prefers concise replies in Slack style",
        ))
        .await
        .unwrap();
        let all = inner.recall("用户", 10).await.unwrap();
        assert_eq!(all.len(), 1, "preference about 用户 should be kept");
        // Both harmless entries must have reached the store, not just one.
        let everything = inner.recall("", 10).await.unwrap();
        assert_eq!(
            everything.len(),
            2,
            "both preferences should be kept: {everything:?}"
        );
    }

    #[tokio::test]
    async fn duplicate_is_dropped() {
        let inner: Arc<dyn Memory> = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone()).with_dedup_threshold(0.6);
        mem.write(MemoryEntry::new(
            "user prefers concise replies written in Slack style",
        ))
        .await
        .unwrap();
        // Near-duplicate phrasing → tokens overlap ≥ 0.6 → should be dropped.
        mem.write(MemoryEntry::new(
            "user prefers concise replies in Slack tone",
        ))
        .await
        .unwrap();
        let all = inner.recall("user", 10).await.unwrap();
        assert_eq!(
            all.len(),
            1,
            "near-duplicate should not double-store: {all:?}"
        );
    }

    #[tokio::test]
    async fn blocked_substring_works() {
        let inner: Arc<dyn Memory> = Arc::new(VecMemory::default());
        let mem = GuardedMemory::new(inner.clone()).with_blocked_substring("password");
        mem.write(MemoryEntry::new("user's password reset is hunter2"))
            .await
            .unwrap();
        let all = inner.recall("password", 10).await.unwrap();
        assert!(all.is_empty());
    }
}
324}