1use crate::memory::{Memory, MemoryStore, now_rfc3339};
4use crate::{PawanError, Result};
5use std::cmp::Ordering;
6use std::collections::HashSet;
7
/// Upper bound, in Unicode scalar values (`char`s), on the length of a memory key.
///
/// `validate_key` rejects longer keys, and key disambiguation truncates the
/// base key so that the suffixed result still fits within this bound.
pub const MAX_KEY_CHARS: usize = 256;

/// Upper bound, in bytes, on stored memory content (1 MiB).
///
/// `sanitize_content` truncates anything longer down to this size.
pub const MAX_CONTENT_BYTES: usize = 1 << 20;
13
/// A [`MemoryStore`] paired with the id of the session it is used on behalf of.
///
/// The methods on this type refuse to run when `session_id` is empty and use
/// the id to decide which stored memories may be read, overwritten or deleted.
pub struct SessionScopedMemory {
    /// Backing store that actually persists and searches memories.
    store: MemoryStore,
    /// Id of the owning session; compared against `Memory::source_session`.
    session_id: String,
}
20
21impl SessionScopedMemory {
22 pub fn new(store: MemoryStore, session_id: String) -> Self {
23 Self { store, session_id }
24 }
25
26 fn require_session(&self) -> Result<()> {
27 if self.session_id.is_empty() {
28 return Err(PawanError::Config(
29 "SessionScopedMemory requires a non-empty session_id".to_string(),
30 ));
31 }
32 Ok(())
33 }
34
35 pub fn save(&self, memory: &Memory) -> Result<()> {
37 self.require_session()?;
38
39 let mut key = sanitize_key(&memory.key);
40 validate_key(&key)?;
41 key = self.disambiguate_key(key)?;
42
43 let now = now_rfc3339();
44 let content = sanitize_content(&memory.content);
45
46 let (created_at, relevance_score) = match self.store.load(&key) {
47 Ok(existing) if existing.source_session == self.session_id => (
48 existing.created_at,
49 memory.relevance_score.max(existing.relevance_score),
50 ),
51 Err(PawanError::NotFound(_)) => (now.clone(), memory.relevance_score),
52 Ok(_) => {
53 return Err(PawanError::Tool(
54 "Memory key conflict after disambiguation; refusing to clobber a foreign session"
55 .to_string(),
56 ));
57 }
58 Err(e) => return Err(e),
59 };
60
61 let to_store = Memory {
62 key,
63 content,
64 source_session: self.session_id.clone(),
65 created_at,
66 updated_at: now,
67 relevance_score,
68 };
69
70 self.store.save(&to_store)
71 }
72
73 pub fn get_relevant(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
75 self.require_session()?;
76 if limit == 0 {
77 return Ok(vec![]);
78 }
79
80 let pool = (limit.saturating_mul(8)).max(32).min(2000);
82 let mut hits: Vec<Memory> = self
83 .store
84 .search(query, pool)?
85 .into_iter()
86 .filter(|m| m.source_session == self.session_id || m.is_shared())
87 .collect();
88
89 let mut seen: HashSet<String> = hits.iter().map(|m| m.key.clone()).collect();
90 if let Ok(keys) = self.store.list() {
91 for k in keys {
92 if seen.contains(&k) {
93 continue;
94 }
95 if let Ok(m) = self.store.load(&k) {
96 if m.is_shared() {
97 seen.insert(m.key.clone());
98 hits.push(m);
99 }
100 }
101 }
102 }
103
104 hits.sort_by(|a, b| {
105 let s = b
106 .relevance_score
107 .partial_cmp(&a.relevance_score)
108 .unwrap_or(Ordering::Equal);
109 if s != Ordering::Equal {
110 return s;
111 }
112 b.updated_at.cmp(&a.updated_at)
113 });
114 hits.truncate(limit);
115 Ok(hits)
116 }
117
118 pub fn cleanup_session(&self) -> Result<()> {
120 self.require_session()?;
121 if !self.store.base_path.exists() {
122 return Ok(());
123 }
124
125 for entry in std::fs::read_dir(&self.store.base_path)? {
126 let entry = entry?;
127 let path = entry.path();
128 if path.extension().and_then(|s| s.to_str()) != Some("json") {
129 continue;
130 }
131 let bytes = match std::fs::read(&path) {
132 Ok(b) => b,
133 Err(_) => continue,
134 };
135 let mem: Memory = match serde_json::from_slice(&bytes) {
136 Ok(m) => m,
137 Err(_) => continue,
138 };
139 if mem.source_session == self.session_id && !mem.is_shared() {
140 self.store.delete(&mem.key)?;
141 }
142 }
143 Ok(())
144 }
145
146 fn disambiguate_key(&self, base: String) -> Result<String> {
147 let original = base.clone();
148 let mut candidate = base;
149 let mut n = 0u32;
150
151 loop {
152 match self.store.load(&candidate) {
153 Ok(existing) if existing.source_session == self.session_id => {
154 return Ok(candidate);
155 }
156 Ok(_other) => {
157 n += 1;
158 let suffix = format!("__{n}");
159 let max_base = MAX_KEY_CHARS.saturating_sub(suffix.chars().count());
160 if max_base == 0 {
161 return Err(PawanError::Tool(
162 "Could not reserve space for a disambiguation suffix on the memory key"
163 .to_string(),
164 ));
165 }
166 let truncated = truncate_to_max_chars(&original, max_base);
167 candidate = format!("{truncated}{suffix}");
168 }
169 Err(PawanError::NotFound(_)) => return Ok(candidate),
170 Err(e) => return Err(e),
171 }
172 }
173 }
174}
175
/// Returns `s` with every character removed that is not an ASCII letter,
/// an ASCII digit, `-`, `_` or `.`.
///
/// The result may be empty; callers are expected to validate it afterwards.
pub fn sanitize_key(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.') {
            cleaned.push(ch);
        }
    }
    cleaned
}
184
185pub fn sanitize_content(s: &str) -> String {
187 let no_nul: String = s.chars().filter(|&c| c != '\0').collect();
188 truncate_to_max_bytes(&no_nul, MAX_CONTENT_BYTES)
189}
190
191pub fn validate_key(key: &str) -> Result<()> {
193 if key.is_empty() {
194 return Err(PawanError::Tool(
195 "Memory key is empty (or became empty after sanitization)".to_string(),
196 ));
197 }
198 if key.chars().count() > MAX_KEY_CHARS {
199 return Err(PawanError::Tool(format!(
200 "Memory key exceeds {MAX_KEY_CHARS} characters"
201 )));
202 }
203 if !key
204 .chars()
205 .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.')
206 {
207 return Err(PawanError::Tool(
208 "Memory key contains disallowed characters (allowed: A-Z, a-z, 0-9, -, _, .)"
209 .to_string(),
210 ));
211 }
212 Ok(())
213}
214
/// Returns the longest prefix of `s` that is at most `max` bytes long and
/// ends on a UTF-8 character boundary.
///
/// Strings that already fit are returned unchanged.
fn truncate_to_max_bytes(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }
    // Walk back from `max` to the nearest boundary; index 0 always qualifies.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    s[..cut].to_owned()
}
225
/// Returns the first `max_chars` characters of `s`, or all of `s` when it has
/// no more than that many characters.
fn truncate_to_max_chars(s: &str, max_chars: usize) -> String {
    // The byte offset of the character at position `max_chars`, if one
    // exists, is exactly where the kept prefix ends.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_string(),
        None => s.to_string(),
    }
}
232
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a store rooted at `<tempdir>/memories`. The `TempDir` handle is
    /// returned alongside it so the directory stays alive for the whole test.
    fn temp_store() -> (TempDir, MemoryStore) {
        let dir = TempDir::new().unwrap();
        let store = MemoryStore::new(dir.path().join("memories"));
        (dir, store)
    }

    #[test]
    fn sanitize_strips_unsafe_key_chars() {
        let cases = [
            ("a/b@x#y", "abxy"),
            ("arch.module-name", "arch.module-name"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_key(input), expected);
        }
    }

    #[test]
    fn validate_key_rejects_bad_keys() {
        let too_long = "a".repeat(MAX_KEY_CHARS + 1);
        for bad in ["", "bad/key", too_long.as_str()] {
            assert!(validate_key(bad).is_err());
        }
    }

    #[test]
    fn sanitize_content_strips_nul_and_truncates() {
        let oversized = "a\0b".repeat(MAX_CONTENT_BYTES);
        let cleaned = sanitize_content(&oversized);
        assert!(!cleaned.contains('\0'));
        assert!(cleaned.len() <= MAX_CONTENT_BYTES);
    }

    #[test]
    fn session_fence_filters_foreign_session() {
        /// Builds a fixture memory with a relevance score of 1.0 and fresh timestamps.
        fn note(key: &str, content: &str, session: &str) -> Memory {
            Memory {
                key: key.to_string(),
                content: content.to_string(),
                source_session: session.to_string(),
                created_at: now_rfc3339(),
                updated_at: now_rfc3339(),
                relevance_score: 1.0,
            }
        }

        let (_dir, store) = temp_store();

        let fixtures = [
            note("note.a", "local debug for session A", "sess-a"),
            note("note.b", "Architecture decision: use modules", "sess-b"),
            note("note.c", "Private session B debug scratchpad", "sess-b"),
        ];
        for fixture in &fixtures {
            store.save(fixture).unwrap();
        }

        let scoped = SessionScopedMemory::new(store, "sess-a".to_string());
        let found = scoped.get_relevant("debug", 10).unwrap();
        let has_key = |wanted: &str| found.iter().any(|m| m.key == wanted);
        assert!(has_key("note.a"));
        assert!(has_key("note.b"));
        assert!(!has_key("note.c"));
    }

    #[test]
    fn test_session_scoped_memory_requires_non_empty_session_id() {
        let (_dir, store) = temp_store();
        let scoped = SessionScopedMemory::new(store, String::new());
        let orphan = Memory {
            key: "k".to_string(),
            content: "c".to_string(),
            source_session: String::new(),
            created_at: now_rfc3339(),
            updated_at: now_rfc3339(),
            relevance_score: 0.1,
        };
        assert!(scoped.save(&orphan).is_err());
    }

    #[test]
    fn test_get_relevant_empty_query_returns_empty() {
        let (_dir, store) = temp_store();
        let scoped = SessionScopedMemory::new(store, "s".to_string());
        let results = scoped.get_relevant(" ", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_sanitize_and_validate_key_edge_cases() {
        assert_eq!(sanitize_key(""), "");
        assert_eq!(sanitize_key("a@b"), "ab");
        assert!(validate_key("valid.key-1_").is_ok());

        assert!(sanitize_content("").is_empty());

        let oversized = "x".repeat(MAX_CONTENT_BYTES + 10_000);
        assert!(sanitize_content(&oversized).len() <= MAX_CONTENT_BYTES);
    }
}