Skip to main content

pawan/
memory_fence.rs

1//! Session-scoped memory boundaries and key/content sanitization.
2
3use crate::memory::{now_rfc3339, Memory, MemoryStore};
4use crate::{PawanError, Result};
5use std::cmp::Ordering;
6use std::collections::HashSet;
7
/// Maximum key length in Unicode scalar values (after sanitization).
///
/// `validate_key` rejects longer keys, and key disambiguation truncates the
/// base key so that the base plus its `__N` suffix still fits within this limit.
pub const MAX_KEY_CHARS: usize = 256;

/// Maximum memory content size in bytes (1 MiB).
///
/// `sanitize_content` truncates content to at most this many UTF-8 bytes; the
/// limit applies to the content string itself, not to any serialized record.
pub const MAX_CONTENT_BYTES: usize = 1024 * 1024;
13
/// A memory store that is scoped to a specific session.
/// Prevents memory from one session leaking into another.
///
/// Writes are tagged with `session_id`; reads only surface memories whose
/// `source_session` equals `session_id` or for which `Memory::is_shared` is true.
pub struct SessionScopedMemory {
    /// Backing store that this fence wraps.
    store: MemoryStore,
    /// Identifier of the owning session; operations fail while it is empty.
    session_id: String,
}
20
21impl SessionScopedMemory {
22    pub fn new(store: MemoryStore, session_id: String) -> Self {
23        Self { store, session_id }
24    }
25
26    fn require_session(&self) -> Result<()> {
27        if self.session_id.is_empty() {
28            return Err(PawanError::Config(
29                "SessionScopedMemory requires a non-empty session_id".to_string(),
30            ));
31        }
32        Ok(())
33    }
34
35    /// Save a memory tagged with this session.
36    pub fn save(&self, memory: &Memory) -> Result<()> {
37        self.require_session()?;
38
39        let mut key = sanitize_key(&memory.key);
40        validate_key(&key)?;
41        key = self.disambiguate_key(key)?;
42
43        let now = now_rfc3339();
44        let content = sanitize_content(&memory.content);
45
46        let (created_at, relevance_score) = match self.store.load(&key) {
47            Ok(existing) if existing.source_session == self.session_id => (
48                existing.created_at,
49                memory.relevance_score.max(existing.relevance_score),
50            ),
51            Err(PawanError::NotFound(_)) => (now.clone(), memory.relevance_score),
52            Ok(_) => {
53                return Err(PawanError::Tool(
54                    "Memory key conflict after disambiguation; refusing to clobber a foreign session"
55                        .to_string(),
56                ));
57            }
58            Err(e) => return Err(e),
59        };
60
61        let to_store = Memory {
62            key,
63            content,
64            source_session: self.session_id.clone(),
65            created_at,
66            updated_at: now,
67            relevance_score,
68        };
69
70        self.store.save(&to_store)
71    }
72
73    /// Only return memories from this session (or shared cross-session knowledge).
74    pub fn get_relevant(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
75        self.require_session()?;
76        if limit == 0 {
77            return Ok(vec![]);
78        }
79
80        // Pull a larger candidate pool, then apply the session fence.
81        let pool = limit.saturating_mul(8).clamp(32, 2000);
82        let mut hits: Vec<Memory> = self
83            .store
84            .search(query, pool)?
85            .into_iter()
86            .filter(|m| m.source_session == self.session_id || m.is_shared())
87            .collect();
88
89        let mut seen: HashSet<String> = hits.iter().map(|m| m.key.clone()).collect();
90        if let Ok(keys) = self.store.list() {
91            for k in keys {
92                if seen.contains(&k) {
93                    continue;
94                }
95                if let Ok(m) = self.store.load(&k) {
96                    if m.is_shared() {
97                        seen.insert(m.key.clone());
98                        hits.push(m);
99                    }
100                }
101            }
102        }
103
104        hits.sort_by(|a, b| {
105            let s = b
106                .relevance_score
107                .partial_cmp(&a.relevance_score)
108                .unwrap_or(Ordering::Equal);
109            if s != Ordering::Equal {
110                return s;
111            }
112            b.updated_at.cmp(&a.updated_at)
113        });
114        hits.truncate(limit);
115        Ok(hits)
116    }
117
118    /// Remove session-local memories; shared memories are retained for other sessions.
119    pub fn cleanup_session(&self) -> Result<()> {
120        self.require_session()?;
121        if !self.store.base_path.exists() {
122            return Ok(());
123        }
124
125        for entry in std::fs::read_dir(&self.store.base_path)? {
126            let entry = entry?;
127            let path = entry.path();
128            if path.extension().and_then(|s| s.to_str()) != Some("json") {
129                continue;
130            }
131            let bytes = match std::fs::read(&path) {
132                Ok(b) => b,
133                Err(_) => continue,
134            };
135            let mem: Memory = match serde_json::from_slice(&bytes) {
136                Ok(m) => m,
137                Err(_) => continue,
138            };
139            if mem.source_session == self.session_id && !mem.is_shared() {
140                self.store.delete(&mem.key)?;
141            }
142        }
143        Ok(())
144    }
145
146    fn disambiguate_key(&self, base: String) -> Result<String> {
147        let original = base.clone();
148        let mut candidate = base;
149        let mut n = 0u32;
150
151        loop {
152            match self.store.load(&candidate) {
153                Ok(existing) if existing.source_session == self.session_id => {
154                    return Ok(candidate);
155                }
156                Ok(_other) => {
157                    n += 1;
158                    let suffix = format!("__{n}");
159                    let max_base = MAX_KEY_CHARS.saturating_sub(suffix.chars().count());
160                    if max_base == 0 {
161                        return Err(PawanError::Tool(
162                            "Could not reserve space for a disambiguation suffix on the memory key"
163                                .to_string(),
164                        ));
165                    }
166                    let truncated = truncate_to_max_chars(&original, max_base);
167                    candidate = format!("{truncated}{suffix}");
168                }
169                Err(PawanError::NotFound(_)) => return Ok(candidate),
170                Err(e) => return Err(e),
171            }
172        }
173    }
174}
175
/// Reduce `s` to the characters permitted in memory keys.
///
/// ASCII letters, ASCII digits, `-`, `_` and `.` are kept in their original
/// order; every other character (including non-ASCII letters, whitespace and
/// path separators) is dropped. The result may be empty.
pub fn sanitize_key(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.') {
            cleaned.push(ch);
        }
    }
    cleaned
}
182
183/// Sanitize a memory content string: strip NULs and cap size at 1MB (byte length).
184pub fn sanitize_content(s: &str) -> String {
185    let no_nul: String = s.chars().filter(|&c| c != '\0').collect();
186    truncate_to_max_bytes(&no_nul, MAX_CONTENT_BYTES)
187}
188
189/// Validate that a memory key is safe for filesystem use.
190pub fn validate_key(key: &str) -> Result<()> {
191    if key.is_empty() {
192        return Err(PawanError::Tool(
193            "Memory key is empty (or became empty after sanitization)".to_string(),
194        ));
195    }
196    if key.chars().count() > MAX_KEY_CHARS {
197        return Err(PawanError::Tool(format!(
198            "Memory key exceeds {MAX_KEY_CHARS} characters"
199        )));
200    }
201    if !key
202        .chars()
203        .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.')
204    {
205        return Err(PawanError::Tool(
206            "Memory key contains disallowed characters (allowed: A-Z, a-z, 0-9, -, _, .)"
207                .to_string(),
208        ));
209    }
210    Ok(())
211}
212
/// Return the longest prefix of `s` that is at most `max` bytes long and ends
/// on a character boundary.
///
/// Strings that already fit are returned unchanged (as an owned copy).
fn truncate_to_max_bytes(s: &str, max: usize) -> String {
    if max >= s.len() {
        return s.to_owned();
    }
    // Walk back from `max` to the nearest boundary; offset 0 always qualifies.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    s[..cut].to_owned()
}
223
/// Return the first `max_chars` characters of `s`, or all of `s` when it has
/// no more than that many characters.
fn truncate_to_max_chars(s: &str, max_chars: usize) -> String {
    // The character at index `max_chars` (if any) is the first one to drop;
    // its byte offset is where the kept prefix ends.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_owned(),
        None => s.to_owned(),
    }
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Open a store rooted at `<tmp>/memories`, the layout every store-backed test uses.
    fn store_under(tmp: &TempDir) -> MemoryStore {
        MemoryStore::new(tmp.path().join("memories"))
    }

    /// Build a memory with relevance 1.0 and fresh timestamps.
    fn note(key: &str, content: &str, session: &str) -> Memory {
        Memory {
            key: key.to_string(),
            content: content.to_string(),
            source_session: session.to_string(),
            created_at: now_rfc3339(),
            updated_at: now_rfc3339(),
            relevance_score: 1.0,
        }
    }

    /// Disallowed characters are dropped; allowed punctuation survives.
    #[test]
    fn sanitize_strips_unsafe_key_chars() {
        let cases = [
            ("a/b@x#y", "abxy"),
            ("arch.module-name", "arch.module-name"),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(sanitize_key(raw), expected);
        }
    }

    /// Empty, slash-containing and over-long keys are all rejected.
    #[test]
    fn validate_key_rejects_bad_keys() {
        let too_long = "a".repeat(MAX_KEY_CHARS + 1);
        for &bad in &["", "bad/key", too_long.as_str()] {
            assert!(validate_key(bad).is_err());
        }
    }

    /// Oversized content containing NULs comes back NUL-free and within the cap.
    #[test]
    fn sanitize_content_strips_nul_and_truncates() {
        let noisy = "a\0b".repeat(MAX_CONTENT_BYTES);
        let cleaned = sanitize_content(&noisy);
        assert!(!cleaned.contains('\0'));
        assert!(cleaned.len() <= MAX_CONTENT_BYTES);
    }

    /// Session A sees its own note and session B's `note.b`, but not B's `note.c`.
    /// NOTE(review): this presumably relies on `Memory::is_shared` classifying
    /// `note.b` as shared from its content — confirm against `Memory`.
    #[test]
    fn session_fence_filters_foreign_session() {
        let tmp = TempDir::new().unwrap();
        let store = store_under(&tmp);

        let fixtures = [
            note("note.a", "local debug for session A", "sess-a"),
            note("note.b", "Architecture decision: use modules", "sess-b"),
            note("note.c", "Private session B debug scratchpad", "sess-b"),
        ];
        for mem in &fixtures {
            store.save(mem).unwrap();
        }

        let scoped = SessionScopedMemory::new(store, "sess-a".to_string());
        let found = scoped.get_relevant("debug", 10).unwrap();
        let has = |key: &str| found.iter().any(|m| m.key == key);
        assert!(has("note.a"));
        assert!(has("note.b"));
        assert!(!has("note.c"));
    }

    /// Saving through a fence with an empty session id fails.
    #[test]
    fn test_session_scoped_memory_requires_non_empty_session_id() {
        let tmp = TempDir::new().unwrap();
        let scoped = SessionScopedMemory::new(store_under(&tmp), String::new());
        let orphan = Memory {
            relevance_score: 0.1,
            ..note("k", "c", "")
        };
        assert!(scoped.save(&orphan).is_err());
    }

    /// A whitespace-only query against an empty store yields no results.
    #[test]
    fn test_get_relevant_empty_query_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let scoped = SessionScopedMemory::new(store_under(&tmp), "s".to_string());
        let results = scoped.get_relevant("   ", 10).unwrap();
        assert!(results.is_empty());
    }

    /// Boundary inputs for the key and content helpers.
    #[test]
    fn test_sanitize_and_validate_key_edge_cases() {
        assert_eq!(sanitize_key(""), "");
        assert_eq!(sanitize_key("a@b"), "ab");
        assert!(validate_key("valid.key-1_").is_ok());

        assert!(sanitize_content("").is_empty());

        let oversized = "x".repeat(MAX_CONTENT_BYTES + 10_000);
        assert!(sanitize_content(&oversized).len() <= MAX_CONTENT_BYTES);
    }
}