Skip to main content

pawan/
memory_fence.rs

1//! Session-scoped memory boundaries and key/content sanitization.
2
3use crate::memory::{Memory, MemoryStore, now_rfc3339};
4use crate::{PawanError, Result};
5use std::cmp::Ordering;
6use std::collections::HashSet;
7
/// Upper bound on a memory key's length, counted in `char`s and applied after
/// [`sanitize_key`] (at which point every character is ASCII, so this is also a byte count).
pub const MAX_KEY_CHARS: usize = 1 << 8;
10
/// Upper bound, in bytes, on memory content after sanitization (1 MiB).
pub const MAX_CONTENT_BYTES: usize = 1 << 20;
13
/// A memory store that is scoped to a specific session.
/// Prevents memory from one session leaking into another.
///
/// Reads only surface memories whose `source_session` equals `session_id` or that report
/// themselves as shared (`Memory::is_shared`); writes are always tagged with `session_id`.
pub struct SessionScopedMemory {
    /// Underlying store; it may also hold memories written by other sessions.
    store: MemoryStore,
    /// Owning session. Must be non-empty, otherwise every operation returns an error.
    session_id: String,
}
20
21impl SessionScopedMemory {
22    pub fn new(store: MemoryStore, session_id: String) -> Self {
23        Self { store, session_id }
24    }
25
26    fn require_session(&self) -> Result<()> {
27        if self.session_id.is_empty() {
28            return Err(PawanError::Config(
29                "SessionScopedMemory requires a non-empty session_id".to_string(),
30            ));
31        }
32        Ok(())
33    }
34
35    /// Save a memory tagged with this session.
36    pub fn save(&self, memory: &Memory) -> Result<()> {
37        self.require_session()?;
38
39        let mut key = sanitize_key(&memory.key);
40        validate_key(&key)?;
41        key = self.disambiguate_key(key)?;
42
43        let now = now_rfc3339();
44        let content = sanitize_content(&memory.content);
45
46        let (created_at, relevance_score) = match self.store.load(&key) {
47            Ok(existing) if existing.source_session == self.session_id => (
48                existing.created_at,
49                memory.relevance_score.max(existing.relevance_score),
50            ),
51            Err(PawanError::NotFound(_)) => (now.clone(), memory.relevance_score),
52            Ok(_) => {
53                return Err(PawanError::Tool(
54                    "Memory key conflict after disambiguation; refusing to clobber a foreign session"
55                        .to_string(),
56                ));
57            }
58            Err(e) => return Err(e),
59        };
60
61        let to_store = Memory {
62            key,
63            content,
64            source_session: self.session_id.clone(),
65            created_at,
66            updated_at: now,
67            relevance_score,
68        };
69
70        self.store.save(&to_store)
71    }
72
73    /// Only return memories from this session (or shared cross-session knowledge).
74    pub fn get_relevant(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
75        self.require_session()?;
76        if limit == 0 {
77            return Ok(vec![]);
78        }
79
80        // Pull a larger candidate pool, then apply the session fence.
81        let pool = (limit.saturating_mul(8)).max(32).min(2000);
82        let mut hits: Vec<Memory> = self
83            .store
84            .search(query, pool)?
85            .into_iter()
86            .filter(|m| m.source_session == self.session_id || m.is_shared())
87            .collect();
88
89        let mut seen: HashSet<String> = hits.iter().map(|m| m.key.clone()).collect();
90        if let Ok(keys) = self.store.list() {
91            for k in keys {
92                if seen.contains(&k) {
93                    continue;
94                }
95                if let Ok(m) = self.store.load(&k) {
96                    if m.is_shared() {
97                        seen.insert(m.key.clone());
98                        hits.push(m);
99                    }
100                }
101            }
102        }
103
104        hits.sort_by(|a, b| {
105            let s = b
106                .relevance_score
107                .partial_cmp(&a.relevance_score)
108                .unwrap_or(Ordering::Equal);
109            if s != Ordering::Equal {
110                return s;
111            }
112            b.updated_at.cmp(&a.updated_at)
113        });
114        hits.truncate(limit);
115        Ok(hits)
116    }
117
118    /// Remove session-local memories; shared memories are retained for other sessions.
119    pub fn cleanup_session(&self) -> Result<()> {
120        self.require_session()?;
121        if !self.store.base_path.exists() {
122            return Ok(());
123        }
124
125        for entry in std::fs::read_dir(&self.store.base_path)? {
126            let entry = entry?;
127            let path = entry.path();
128            if path.extension().and_then(|s| s.to_str()) != Some("json") {
129                continue;
130            }
131            let bytes = match std::fs::read(&path) {
132                Ok(b) => b,
133                Err(_) => continue,
134            };
135            let mem: Memory = match serde_json::from_slice(&bytes) {
136                Ok(m) => m,
137                Err(_) => continue,
138            };
139            if mem.source_session == self.session_id && !mem.is_shared() {
140                self.store.delete(&mem.key)?;
141            }
142        }
143        Ok(())
144    }
145
146    fn disambiguate_key(&self, base: String) -> Result<String> {
147        let original = base.clone();
148        let mut candidate = base;
149        let mut n = 0u32;
150
151        loop {
152            match self.store.load(&candidate) {
153                Ok(existing) if existing.source_session == self.session_id => {
154                    return Ok(candidate);
155                }
156                Ok(_other) => {
157                    n += 1;
158                    let suffix = format!("__{n}");
159                    let max_base = MAX_KEY_CHARS.saturating_sub(suffix.chars().count());
160                    if max_base == 0 {
161                        return Err(PawanError::Tool(
162                            "Could not reserve space for a disambiguation suffix on the memory key"
163                                .to_string(),
164                        ));
165                    }
166                    let truncated = truncate_to_max_chars(&original, max_base);
167                    candidate = format!("{truncated}{suffix}");
168                }
169                Err(PawanError::NotFound(_)) => return Ok(candidate),
170                Err(e) => return Err(e),
171            }
172        }
173    }
174}
175
/// Reduce `s` to characters that are safe inside a memory key.
///
/// Only ASCII letters, ASCII digits, `-`, `_` and `.` survive; everything else (including
/// path separators and all non-ASCII characters) is dropped, so the result may be empty.
pub fn sanitize_key(s: &str) -> String {
    let mut kept = String::with_capacity(s.len());
    for c in s.chars() {
        if matches!(c, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.') {
            kept.push(c);
        }
    }
    kept
}
184
185/// Sanitize a memory content string: strip NULs and cap size at 1MB (byte length).
186pub fn sanitize_content(s: &str) -> String {
187    let no_nul: String = s.chars().filter(|&c| c != '\0').collect();
188    truncate_to_max_bytes(&no_nul, MAX_CONTENT_BYTES)
189}
190
191/// Validate that a memory key is safe for filesystem use.
192pub fn validate_key(key: &str) -> Result<()> {
193    if key.is_empty() {
194        return Err(PawanError::Tool(
195            "Memory key is empty (or became empty after sanitization)".to_string(),
196        ));
197    }
198    if key.chars().count() > MAX_KEY_CHARS {
199        return Err(PawanError::Tool(format!(
200            "Memory key exceeds {MAX_KEY_CHARS} characters"
201        )));
202    }
203    if !key
204        .chars()
205        .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.')
206    {
207        return Err(PawanError::Tool(
208            "Memory key contains disallowed characters (allowed: A-Z, a-z, 0-9, -, _, .)"
209                .to_string(),
210        ));
211    }
212    Ok(())
213}
214
/// Return the longest prefix of `s` that is at most `max` bytes and ends on a `char`
/// boundary (so the result is always valid UTF-8).
fn truncate_to_max_bytes(s: &str, max: usize) -> String {
    if max >= s.len() {
        return s.to_owned();
    }
    // `str::get` yields `None` for cut points inside a multi-byte character, so the first
    // hit scanning downward from `max` is the largest valid prefix (0 always succeeds).
    let head = (0..=max)
        .rev()
        .find_map(|end| s.get(..end))
        .unwrap_or_default();
    head.to_owned()
}
225
/// Return the first `max_chars` characters of `s` (all of `s` if it is shorter).
fn truncate_to_max_chars(s: &str, max_chars: usize) -> String {
    // The character at index `max_chars`, if any, is the first one that must be dropped;
    // its byte offset is exactly where to cut.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_string(),
        None => s.to_string(),
    }
}
232
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a memory with fresh timestamps and a relevance score of 1.0; override other
    /// fields with struct-update syntax where a test needs to.
    fn memory(key: &str, content: &str, session: &str) -> Memory {
        let now = now_rfc3339();
        Memory {
            key: key.to_string(),
            content: content.to_string(),
            source_session: session.to_string(),
            created_at: now.clone(),
            updated_at: now,
            relevance_score: 1.0,
        }
    }

    #[test]
    fn sanitize_strips_unsafe_key_chars() {
        assert_eq!(sanitize_key("a/b@x#y"), "abxy");
        assert_eq!(sanitize_key("arch.module-name"), "arch.module-name");
    }

    #[test]
    fn validate_key_rejects_bad_keys() {
        assert!(validate_key("").is_err());
        assert!(validate_key("bad/key").is_err());
        let long: String = "a".repeat(MAX_KEY_CHARS + 1);
        assert!(validate_key(&long).is_err());
    }

    #[test]
    fn sanitize_content_strips_nul_and_truncates() {
        let s = "a\0b".repeat(MAX_CONTENT_BYTES);
        let out = sanitize_content(&s);
        assert!(!out.contains('\0'));
        // Stripping NULs leaves 2 * MAX_CONTENT_BYTES of ASCII, so the cap is hit exactly.
        assert_eq!(out.len(), MAX_CONTENT_BYTES);
    }

    #[test]
    fn sanitize_content_truncates_on_char_boundary() {
        // '€' is 3 bytes; when MAX_CONTENT_BYTES is not a multiple of 3 (as with 1 MiB) a
        // naive byte cut would split a character.
        let s = "€".repeat(MAX_CONTENT_BYTES);
        let out = sanitize_content(&s);
        assert_eq!(out.len(), MAX_CONTENT_BYTES / 3 * 3);
        assert!(out.chars().all(|c| c == '€'));
    }

    #[test]
    fn session_fence_filters_foreign_session() {
        let dir = TempDir::new().unwrap();
        let store = MemoryStore::new(dir.path().join("memories"));

        store
            .save(&memory("note.a", "local debug for session A", "sess-a"))
            .unwrap();
        store
            .save(&memory("note.b", "Architecture decision: use modules", "sess-b"))
            .unwrap();
        store
            .save(&memory("note.c", "Private session B debug scratchpad", "sess-b"))
            .unwrap();

        let scoped = SessionScopedMemory::new(store, "sess-a".to_string());
        let found = scoped.get_relevant("debug", 10).unwrap();
        let keys: Vec<_> = found.iter().map(|m| m.key.as_str()).collect();
        assert!(keys.contains(&"note.a"));
        assert!(keys.contains(&"note.b"));
        assert!(!keys.contains(&"note.c"));
    }

    #[test]
    fn save_disambiguates_key_owned_by_other_session() {
        let dir = TempDir::new().unwrap();
        let root = dir.path().join("memories");
        let store = MemoryStore::new(root.clone());
        store.save(&memory("shared.key", "from B", "sess-b")).unwrap();

        let scoped = SessionScopedMemory::new(MemoryStore::new(root), "sess-a".to_string());
        scoped.save(&memory("shared.key", "from A", "sess-a")).unwrap();

        // Session B's memory is untouched; session A's lands under a suffixed key.
        let b = store.load("shared.key").unwrap();
        assert_eq!(b.source_session, "sess-b");
        assert_eq!(b.content, "from B");
        let a = store.load("shared.key__1").unwrap();
        assert_eq!(a.source_session, "sess-a");
        assert_eq!(a.content, "from A");
    }

    #[test]
    fn test_session_scoped_memory_requires_non_empty_session_id() {
        let dir = TempDir::new().unwrap();
        let store = MemoryStore::new(dir.path().join("memories"));
        let scoped = SessionScopedMemory::new(store, String::new());
        let m = Memory {
            relevance_score: 0.1,
            ..memory("k", "c", "")
        };
        assert!(scoped.save(&m).is_err());
    }

    #[test]
    fn test_get_relevant_empty_query_returns_empty() {
        let dir = TempDir::new().unwrap();
        let store = MemoryStore::new(dir.path().join("memories"));
        let scoped = SessionScopedMemory::new(store, "s".to_string());
        let out = scoped.get_relevant("   ", 10).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn test_sanitize_and_validate_key_edge_cases() {
        assert_eq!(sanitize_key(""), "");
        assert_eq!(sanitize_key("a@b"), "ab");
        assert!(validate_key("valid.key-1_").is_ok());
        let empty_content = sanitize_content("");
        assert!(empty_content.is_empty());
        let big = "x".repeat(MAX_CONTENT_BYTES + 10_000);
        let capped = sanitize_content(&big);
        assert_eq!(capped.len(), MAX_CONTENT_BYTES);
    }
}