use crate::memory::{now_rfc3339, Memory, MemoryStore};
use crate::{PawanError, Result};
use std::cmp::Ordering;
use std::collections::HashSet;
/// Maximum length of a memory key, counted in characters; `validate_key` rejects longer keys.
pub const MAX_KEY_CHARS: usize = 256;
/// Maximum size of memory content in bytes (1 MiB); `sanitize_content` truncates longer content.
pub const MAX_CONTENT_BYTES: usize = 1024 * 1024;
/// A `MemoryStore` paired with the id of the session on whose behalf it is used.
pub struct SessionScopedMemory {
/// Backing store that memories are read from and written to.
store: MemoryStore,
/// Id of the owning session; operations fail with a config error when it is empty.
session_id: String,
}
impl SessionScopedMemory {
    /// Wraps `store` so that every operation is performed on behalf of `session_id`.
    ///
    /// The id is not validated here; an empty id is rejected by each operation instead.
    pub fn new(store: MemoryStore, session_id: String) -> Self {
        Self { store, session_id }
    }

    /// Fails with [`PawanError::Config`] when the session id is empty.
    fn require_session(&self) -> Result<()> {
        if self.session_id.is_empty() {
            return Err(PawanError::Config(
                "SessionScopedMemory requires a non-empty session_id".to_string(),
            ));
        }
        Ok(())
    }

    /// Stores `memory` under this session.
    ///
    /// Only `key`, `content` and `relevance_score` are taken from `memory`; the stored
    /// entry always carries this wrapper's session id, and its timestamps are set here.
    /// The key is sanitized, validated and, if another session already owns it, given a
    /// `__N` suffix. When this session already owns the key, the original `created_at`
    /// is kept and the relevance score becomes the larger of the old and new values.
    ///
    /// # Errors
    ///
    /// * [`PawanError::Config`] if the session id is empty.
    /// * [`PawanError::Tool`] if the key is invalid, or if the chosen key turns out to
    ///   belong to another session when it is re-checked just before writing.
    /// * Any other error reported by the underlying store.
    pub fn save(&self, memory: &Memory) -> Result<()> {
        self.require_session()?;

        let key = sanitize_key(&memory.key);
        validate_key(&key)?;
        let key = self.disambiguate_key(key)?;

        let now = now_rfc3339();
        let content = sanitize_content(&memory.content);

        // Re-load after disambiguation: the key is either free, ours, or (if something
        // changed in between) foreign, in which case we refuse to overwrite it.
        let (created_at, relevance_score) = match self.store.load(&key) {
            Ok(existing) if existing.source_session == self.session_id => (
                existing.created_at,
                memory.relevance_score.max(existing.relevance_score),
            ),
            Err(PawanError::NotFound(_)) => (now.clone(), memory.relevance_score),
            Ok(_) => {
                return Err(PawanError::Tool(
                    "Memory key conflict after disambiguation; refusing to clobber a foreign session"
                        .to_string(),
                ));
            }
            Err(e) => return Err(e),
        };

        let to_store = Memory {
            key,
            content,
            source_session: self.session_id.clone(),
            created_at,
            updated_at: now,
            relevance_score,
        };
        self.store.save(&to_store)
    }

    /// Returns up to `limit` memories visible to this session, most relevant first.
    ///
    /// Visible memories are the search hits for `query` that belong to this session or
    /// that `Memory::is_shared` reports as shared, plus every other shared memory in the
    /// store (added regardless of `query`; listing/loading failures in that step are
    /// ignored on purpose so the search hits are still returned).
    ///
    /// Results are ordered by descending relevance score, then by descending
    /// `updated_at`. Entries whose score is NaN are ranked last.
    ///
    /// # Errors
    ///
    /// [`PawanError::Config`] if the session id is empty, or whatever error the store's
    /// search reports.
    pub fn get_relevant(&self, query: &str, limit: usize) -> Result<Vec<Memory>> {
        self.require_session()?;
        if limit == 0 {
            return Ok(vec![]);
        }

        // Over-fetch, because hits from other sessions are filtered out afterwards.
        let pool = limit.saturating_mul(8).clamp(32, 2000);
        let mut hits: Vec<Memory> = self
            .store
            .search(query, pool)?
            .into_iter()
            .filter(|m| m.source_session == self.session_id || m.is_shared())
            .collect();

        let mut seen: HashSet<String> = hits.iter().map(|m| m.key.clone()).collect();
        if let Ok(keys) = self.store.list() {
            for k in keys {
                if seen.contains(&k) {
                    continue;
                }
                if let Ok(m) = self.store.load(&k) {
                    if m.is_shared() {
                        seen.insert(m.key.clone());
                        hits.push(m);
                    }
                }
            }
        }

        hits.sort_by(|a, b| {
            // `partial_cmp` alone is not a total order once a NaN score is present, and
            // `sort_by` requires one. Put all NaN scores into a single class at the end;
            // ordinary scores keep their plain descending comparison.
            let by_score = match (a.relevance_score.is_nan(), b.relevance_score.is_nan()) {
                (true, true) => Ordering::Equal,
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (false, false) => b
                    .relevance_score
                    .partial_cmp(&a.relevance_score)
                    .unwrap_or(Ordering::Equal),
            };
            by_score.then_with(|| b.updated_at.cmp(&a.updated_at))
        });
        hits.truncate(limit);
        Ok(hits)
    }

    /// Deletes every memory that belongs to this session and is not shared.
    ///
    /// Scans the `.json` files directly under the store's base path; files that cannot
    /// be read or parsed as a `Memory` are skipped. A missing base path is not an error.
    ///
    /// # Errors
    ///
    /// [`PawanError::Config`] if the session id is empty, an I/O error from walking the
    /// directory, or the first error returned by the store while deleting.
    pub fn cleanup_session(&self) -> Result<()> {
        self.require_session()?;
        if !self.store.base_path.exists() {
            return Ok(());
        }

        // Collect first and delete afterwards: removing entries from a directory while
        // it is still being iterated has filesystem-dependent results.
        let mut doomed: Vec<String> = Vec::new();
        for entry in std::fs::read_dir(&self.store.base_path)? {
            let path = entry?.path();
            if path.extension().and_then(|s| s.to_str()) != Some("json") {
                continue;
            }
            let bytes = match std::fs::read(&path) {
                Ok(b) => b,
                Err(_) => continue,
            };
            let mem: Memory = match serde_json::from_slice(&bytes) {
                Ok(m) => m,
                Err(_) => continue,
            };
            if mem.source_session == self.session_id && !mem.is_shared() {
                doomed.push(mem.key);
            }
        }

        for key in &doomed {
            self.store.delete(key)?;
        }
        Ok(())
    }

    /// Picks the key under which this session may store an entry derived from `base`.
    ///
    /// Returns `base` itself when it is unused or already owned by this session.
    /// Otherwise tries `base__1`, `base__2`, ... (shortening `base` so the result stays
    /// within `MAX_KEY_CHARS`) until a key is found that is unused or owned by this
    /// session.
    ///
    /// # Errors
    ///
    /// [`PawanError::Tool`] if no room is left for a suffix, or any store error other
    /// than "not found".
    fn disambiguate_key(&self, base: String) -> Result<String> {
        let original = base.clone();
        let mut candidate = base;
        let mut n = 0u32;
        loop {
            match self.store.load(&candidate) {
                Ok(existing) if existing.source_session == self.session_id => {
                    return Ok(candidate);
                }
                Ok(_other) => {
                    // Taken by another session: derive the next suffixed candidate from
                    // the original key, never from a previous candidate.
                    n += 1;
                    let suffix = format!("__{n}");
                    let max_base = MAX_KEY_CHARS.saturating_sub(suffix.chars().count());
                    if max_base == 0 {
                        return Err(PawanError::Tool(
                            "Could not reserve space for a disambiguation suffix on the memory key"
                                .to_string(),
                        ));
                    }
                    let truncated = truncate_to_max_chars(&original, max_base);
                    candidate = format!("{truncated}{suffix}");
                }
                Err(PawanError::NotFound(_)) => return Ok(candidate),
                Err(e) => return Err(e),
            }
        }
    }
}
/// Returns `s` with every character removed that is not an ASCII letter, an ASCII
/// digit, `-`, `_` or `.`.
pub fn sanitize_key(s: &str) -> String {
    let mut cleaned = String::new();
    for ch in s.chars() {
        if matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.') {
            cleaned.push(ch);
        }
    }
    cleaned
}
/// Returns `s` with all NUL characters removed and the result cut down to at most
/// `MAX_CONTENT_BYTES` bytes.
pub fn sanitize_content(s: &str) -> String {
    let stripped = s.replace('\0', "");
    truncate_to_max_bytes(&stripped, MAX_CONTENT_BYTES)
}
/// Checks that `key` is usable as a memory key.
///
/// # Errors
///
/// Returns [`PawanError::Tool`] when the key is empty, is longer than `MAX_KEY_CHARS`
/// characters, or contains anything other than ASCII letters, ASCII digits, `-`, `_`
/// and `.` (checked in that order).
pub fn validate_key(key: &str) -> Result<()> {
    // A character at index MAX_KEY_CHARS exists exactly when the key is too long.
    let too_long = key.chars().nth(MAX_KEY_CHARS).is_some();

    let problem = if key.is_empty() {
        Some("Memory key is empty (or became empty after sanitization)".to_string())
    } else if too_long {
        Some(format!("Memory key exceeds {MAX_KEY_CHARS} characters"))
    } else if key
        .chars()
        .any(|ch| !matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '-' | '_' | '.'))
    {
        Some(
            "Memory key contains disallowed characters (allowed: A-Z, a-z, 0-9, -, _, .)"
                .to_string(),
        )
    } else {
        None
    };

    match problem {
        Some(message) => Err(PawanError::Tool(message)),
        None => Ok(()),
    }
}
/// Returns the longest prefix of `s` that is at most `max` bytes long and ends on a
/// character boundary.
fn truncate_to_max_bytes(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }
    // Walk back from `max` to the nearest boundary; index 0 always qualifies.
    let cut = (0..=max)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    s[..cut].to_owned()
}
/// Returns the first `max_chars` characters of `s`, or all of `s` if it is shorter.
fn truncate_to_max_chars(s: &str, max_chars: usize) -> String {
    // The character at index `max_chars` (if any) starts at the byte offset to cut at.
    match s.char_indices().nth(max_chars) {
        Some((cut, _)) => s[..cut].to_owned(),
        None => s.to_owned(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a store under a fresh temporary directory. The `TempDir` is returned so
    /// the directory stays alive for the duration of the test.
    fn temp_store() -> (TempDir, MemoryStore) {
        let dir = TempDir::new().unwrap();
        let store = MemoryStore::new(dir.path().join("memories"));
        (dir, store)
    }

    /// Builds a memory with relevance 1.0 and both timestamps set to the current time.
    fn note(key: &str, content: &str, session: &str) -> Memory {
        Memory {
            key: key.to_string(),
            content: content.to_string(),
            source_session: session.to_string(),
            created_at: now_rfc3339(),
            updated_at: now_rfc3339(),
            relevance_score: 1.0,
        }
    }

    #[test]
    fn sanitize_strips_unsafe_key_chars() {
        let cases = [("a/b@x#y", "abxy"), ("arch.module-name", "arch.module-name")];
        for (raw, expected) in cases.iter() {
            assert_eq!(sanitize_key(raw), *expected);
        }
    }

    #[test]
    fn validate_key_rejects_bad_keys() {
        let too_long = "a".repeat(MAX_KEY_CHARS + 1);
        assert!(validate_key("").is_err());
        assert!(validate_key("bad/key").is_err());
        assert!(validate_key(&too_long).is_err());
    }

    #[test]
    fn sanitize_content_strips_nul_and_truncates() {
        let raw = "a\0b".repeat(MAX_CONTENT_BYTES);
        let cleaned = sanitize_content(&raw);
        assert!(!cleaned.contains('\0'));
        assert!(cleaned.len() <= MAX_CONTENT_BYTES);
    }

    #[test]
    fn session_fence_filters_foreign_session() {
        let (_dir, store) = temp_store();
        let seeded = [
            note("note.a", "local debug for session A", "sess-a"),
            // Expected to be visible from sess-a, presumably because `is_shared`
            // classifies it as shared.
            note("note.b", "Architecture decision: use modules", "sess-b"),
            note("note.c", "Private session B debug scratchpad", "sess-b"),
        ];
        for m in &seeded {
            store.save(m).unwrap();
        }

        let scoped = SessionScopedMemory::new(store, "sess-a".to_string());
        let found = scoped.get_relevant("debug", 10).unwrap();
        let has = |key: &str| found.iter().any(|m| m.key == key);

        assert!(has("note.a"));
        assert!(has("note.b"));
        assert!(!has("note.c"));
    }

    #[test]
    fn test_session_scoped_memory_requires_non_empty_session_id() {
        let (_dir, store) = temp_store();
        let scoped = SessionScopedMemory::new(store, String::new());
        let m = Memory {
            key: "k".to_string(),
            content: "c".to_string(),
            source_session: String::new(),
            created_at: now_rfc3339(),
            updated_at: now_rfc3339(),
            relevance_score: 0.1,
        };
        let outcome = scoped.save(&m);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_get_relevant_empty_query_returns_empty() {
        let (_dir, store) = temp_store();
        let scoped = SessionScopedMemory::new(store, "s".to_string());
        let found = scoped.get_relevant(" ", 10).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn test_sanitize_and_validate_key_edge_cases() {
        assert_eq!(sanitize_key(""), "");
        assert_eq!(sanitize_key("a@b"), "ab");
        assert!(validate_key("valid.key-1_").is_ok());

        assert!(sanitize_content("").is_empty());

        let oversized = "x".repeat(MAX_CONTENT_BYTES + 10_000);
        assert!(sanitize_content(&oversized).len() <= MAX_CONTENT_BYTES);
    }
}