1use crate::memory::{now_rfc3339, Memory, MemoryStore};
4use crate::{PawanError, Result};
5use std::cmp::Ordering;
6use std::collections::HashSet;
7
/// Maximum length of a memory key, counted in characters.
///
/// `validate_key` rejects longer keys, and key disambiguation truncates the base key so
/// that the suffixed key still fits within this limit.
pub const MAX_KEY_CHARS: usize = 256;
10
/// Maximum size of stored memory content in bytes (1 MiB).
///
/// `sanitize_content` truncates longer content to this many bytes.
pub const MAX_CONTENT_BYTES: usize = 1024 * 1024;
13
/// A wrapper around [`MemoryStore`] that performs every operation on behalf of one session.
///
/// Records written through the wrapper are stamped with `session_id`. Reads return the
/// session's own records plus any record for which `Memory::is_shared()` is true.
pub struct SessionScopedMemory {
    /// Backing store; it may also contain records written by other sessions.
    store: MemoryStore,
    /// Session this wrapper acts for. The public operations reject an empty id.
    session_id: String,
}
20
21impl SessionScopedMemory {
22 pub fn new(store: MemoryStore, session_id: String) -> Self {
23 Self { store, session_id }
24 }
25
26 fn require_session(&self) -> Result<()> {
27 if self.session_id.is_empty() {
28 return Err(PawanError::Config(
29 "SessionScopedMemory requires a non-empty session_id".to_string(),
30 ));
31 }
32 Ok(())
33 }
34
35 pub fn save(&self, memory: &Memory) -> Result<()> {
37 self.require_session()?;
38
39 let mut key = sanitize_key(&memory.key);
40 validate_key(&key)?;
41 key = self.disambiguate_key(key)?;
42
43 let now = now_rfc3339();
44 let content = sanitize_content(&memory.content);
45
46 let (created_at, relevance_score) = match self.store.load(&key) {
47 Ok(existing) if existing.source_session == self.session_id => (
48 existing.created_at,
49 memory.relevance_score.max(existing.relevance_score),
50 ),
51 Err(PawanError::NotFound(_)) => (now.clone(), memory.relevance_score),
52 Ok(_) => {
53 return Err(PawanError::Tool(
54 "Memory key conflict after disambiguation; refusing to clobber a foreign session"
55 .to_string(),
56 ));
57 }
58 Err(e) => return Err(e),
59 };
60
61 let to_store = Memory {
62 key,
63 content,
64 source_session: self.session_id.clone(),
65 created_at,
66 updated_at: now,
67 relevance_score,
68 };
69
70 self.store.save(&to_store)
71 }
72
73 pub fn get_relevant(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
75 self.require_session()?;
76 if limit == 0 {
77 return Ok(vec![]);
78 }
79
80 let pool = limit.saturating_mul(8).clamp(32, 2000);
82 let mut hits: Vec<Memory> = self
83 .store
84 .search(query, pool)?
85 .into_iter()
86 .filter(|m| m.source_session == self.session_id || m.is_shared())
87 .collect();
88
89 let mut seen: HashSet<String> = hits.iter().map(|m| m.key.clone()).collect();
90 if let Ok(keys) = self.store.list() {
91 for k in keys {
92 if seen.contains(&k) {
93 continue;
94 }
95 if let Ok(m) = self.store.load(&k) {
96 if m.is_shared() {
97 seen.insert(m.key.clone());
98 hits.push(m);
99 }
100 }
101 }
102 }
103
104 hits.sort_by(|a, b| {
105 let s = b
106 .relevance_score
107 .partial_cmp(&a.relevance_score)
108 .unwrap_or(Ordering::Equal);
109 if s != Ordering::Equal {
110 return s;
111 }
112 b.updated_at.cmp(&a.updated_at)
113 });
114 hits.truncate(limit);
115 Ok(hits)
116 }
117
118 pub fn cleanup_session(&self) -> Result<()> {
120 self.require_session()?;
121 if !self.store.base_path.exists() {
122 return Ok(());
123 }
124
125 for entry in std::fs::read_dir(&self.store.base_path)? {
126 let entry = entry?;
127 let path = entry.path();
128 if path.extension().and_then(|s| s.to_str()) != Some("json") {
129 continue;
130 }
131 let bytes = match std::fs::read(&path) {
132 Ok(b) => b,
133 Err(_) => continue,
134 };
135 let mem: Memory = match serde_json::from_slice(&bytes) {
136 Ok(m) => m,
137 Err(_) => continue,
138 };
139 if mem.source_session == self.session_id && !mem.is_shared() {
140 self.store.delete(&mem.key)?;
141 }
142 }
143 Ok(())
144 }
145
146 fn disambiguate_key(&self, base: String) -> Result<String> {
147 let original = base.clone();
148 let mut candidate = base;
149 let mut n = 0u32;
150
151 loop {
152 match self.store.load(&candidate) {
153 Ok(existing) if existing.source_session == self.session_id => {
154 return Ok(candidate);
155 }
156 Ok(_other) => {
157 n += 1;
158 let suffix = format!("__{n}");
159 let max_base = MAX_KEY_CHARS.saturating_sub(suffix.chars().count());
160 if max_base == 0 {
161 return Err(PawanError::Tool(
162 "Could not reserve space for a disambiguation suffix on the memory key"
163 .to_string(),
164 ));
165 }
166 let truncated = truncate_to_max_chars(&original, max_base);
167 candidate = format!("{truncated}{suffix}");
168 }
169 Err(PawanError::NotFound(_)) => return Ok(candidate),
170 Err(e) => return Err(e),
171 }
172 }
173 }
174}
175
/// Returns `s` with every character other than ASCII letters, ASCII digits, `-`, `_`
/// and `.` removed. The result may be empty.
pub fn sanitize_key(s: &str) -> String {
    let mut cleaned = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.') {
            cleaned.push(ch);
        }
    }
    cleaned
}
182
183pub fn sanitize_content(s: &str) -> String {
185 let no_nul: String = s.chars().filter(|&c| c != '\0').collect();
186 truncate_to_max_bytes(&no_nul, MAX_CONTENT_BYTES)
187}
188
189pub fn validate_key(key: &str) -> Result<()> {
191 if key.is_empty() {
192 return Err(PawanError::Tool(
193 "Memory key is empty (or became empty after sanitization)".to_string(),
194 ));
195 }
196 if key.chars().count() > MAX_KEY_CHARS {
197 return Err(PawanError::Tool(format!(
198 "Memory key exceeds {MAX_KEY_CHARS} characters"
199 )));
200 }
201 if !key
202 .chars()
203 .all(|ch| ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' || ch == '.')
204 {
205 return Err(PawanError::Tool(
206 "Memory key contains disallowed characters (allowed: A-Z, a-z, 0-9, -, _, .)"
207 .to_string(),
208 ));
209 }
210 Ok(())
211}
212
/// Returns the longest prefix of `s` that is at most `max` bytes long and ends on a
/// character boundary, so a multi-byte character is never split.
fn truncate_to_max_bytes(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }
    // Walk down from `max` to the nearest boundary; offset 0 is always a boundary.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    s[..cut].to_owned()
}
223
/// Returns the first `max_chars` characters of `s`, or all of `s` when it is not longer
/// than that.
fn truncate_to_max_chars(s: &str, max_chars: usize) -> String {
    // The byte offset of the character at index `max_chars` is where the prefix ends;
    // when no such character exists the string already fits.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_owned(),
        None => s.to_owned(),
    }
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the store so the caller keeps it alive for the
    /// duration of the test.
    fn temp_store() -> (TempDir, MemoryStore) {
        let dir = TempDir::new().unwrap();
        let store = MemoryStore::new(dir.path().join("memories"));
        (dir, store)
    }

    /// Builds a memory with relevance score `1.0` and both timestamps set to "now".
    fn memory(key: &str, content: &str, session: &str) -> Memory {
        Memory {
            key: key.to_string(),
            content: content.to_string(),
            source_session: session.to_string(),
            created_at: now_rfc3339(),
            updated_at: now_rfc3339(),
            relevance_score: 1.0,
        }
    }

    #[test]
    fn sanitize_strips_unsafe_key_chars() {
        let cases = [
            ("a/b@x#y", "abxy"),
            ("arch.module-name", "arch.module-name"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_key(input), expected);
        }
    }

    #[test]
    fn validate_key_rejects_bad_keys() {
        let too_long = "a".repeat(MAX_KEY_CHARS + 1);
        for bad in ["", "bad/key", too_long.as_str()] {
            assert!(validate_key(bad).is_err());
        }
    }

    #[test]
    fn sanitize_content_strips_nul_and_truncates() {
        let raw = "a\0b".repeat(MAX_CONTENT_BYTES);
        let cleaned = sanitize_content(&raw);
        assert!(!cleaned.contains('\0'));
        assert!(cleaned.len() <= MAX_CONTENT_BYTES);
    }

    #[test]
    fn session_fence_filters_foreign_session() {
        let (_dir, store) = temp_store();

        let fixtures = [
            memory("note.a", "local debug for session A", "sess-a"),
            // Written by sess-b but expected to be visible to sess-a; presumably
            // `Memory::is_shared()` treats this content as shared — confirm there.
            memory("note.b", "Architecture decision: use modules", "sess-b"),
            memory("note.c", "Private session B debug scratchpad", "sess-b"),
        ];
        for fixture in &fixtures {
            store.save(fixture).unwrap();
        }

        let scoped = SessionScopedMemory::new(store, "sess-a".to_string());
        let found = scoped.get_relevant("debug", 10).unwrap();
        let returned = |key: &str| found.iter().any(|m| m.key == key);

        assert!(returned("note.a"));
        assert!(returned("note.b"));
        assert!(!returned("note.c"));
    }

    #[test]
    fn test_session_scoped_memory_requires_non_empty_session_id() {
        let (_dir, store) = temp_store();
        let scoped = SessionScopedMemory::new(store, String::new());
        let m = Memory {
            relevance_score: 0.1,
            ..memory("k", "c", "")
        };
        assert!(scoped.save(&m).is_err());
    }

    #[test]
    fn test_get_relevant_empty_query_returns_empty() {
        let (_dir, store) = temp_store();
        let scoped = SessionScopedMemory::new(store, "s".to_string());
        let out = scoped.get_relevant(" ", 10).unwrap();
        assert!(out.is_empty());
    }

    #[test]
    fn test_sanitize_and_validate_key_edge_cases() {
        assert_eq!(sanitize_key(""), "");
        assert_eq!(sanitize_key("a@b"), "ab");
        assert!(validate_key("valid.key-1_").is_ok());

        assert!(sanitize_content("").is_empty());

        let oversized = "x".repeat(MAX_CONTENT_BYTES + 10_000);
        assert!(sanitize_content(&oversized).len() <= MAX_CONTENT_BYTES);
    }
}