use anyhow::Result;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Mutex, OnceLock};
use std::time::SystemTime;
/// Project-level memory file names, probed in this order inside the project root.
const PROJECT_MEMORY_FILES: &[&str] = &["MEMORY.md", "CLAUDE.md", "AGENTS.md"];
/// File name of the global memory file (joined onto `<home>/.config/koda/`).
const GLOBAL_MEMORY_FILE: &str = "memory.md";
/// File name `append` falls back to when no project memory file exists yet.
const KODA_MEMORY_FILE: &str = "MEMORY.md";
/// One cached file read, together with the metadata used to decide whether it
/// is still fresh.
struct CachedEntry {
    /// Modification time of the file when `content` was read.
    mtime: SystemTime,
    /// Length in bytes reported by the file's metadata when `content` was read.
    len: u64,
    /// The file's text as read from disk.
    content: String,
}
/// Process-wide cache of memory-file contents, keyed by file path and created
/// lazily on first access.
static MEMORY_CACHE: OnceLock<Mutex<HashMap<PathBuf, CachedEntry>>> = OnceLock::new();

/// Returns the shared cache, initializing it to an empty map on first use.
fn cache() -> &'static Mutex<HashMap<PathBuf, CachedEntry>> {
    MEMORY_CACHE.get_or_init(Mutex::default)
}
/// Returns the contents of `path`, reusing the cached copy when it is still fresh.
///
/// A cached copy counts as fresh when both the modification time and the byte
/// length stored with it equal the file's current metadata. The second tuple
/// element is `true` when the text came from the cache and `false` when the
/// file was read from disk (in which case the cache entry is refreshed).
///
/// # Errors
/// Propagates failures from reading the file's metadata, its modification
/// time, or its contents.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
fn read_through_cache(path: &Path) -> Result<(String, bool)> {
    let metadata = std::fs::metadata(path)?;
    let modified = metadata.modified()?;
    let size = metadata.len();
    // Key by canonical path; fall back to the path as given when that fails.
    let key = match path.canonicalize() {
        Ok(resolved) => resolved,
        Err(_) => path.to_path_buf(),
    };

    // Copy the text out inside a short scope so the lock is not held longer
    // than the lookup itself.
    let fresh = {
        let guard = cache().lock().expect("memory cache mutex poisoned");
        let hit = guard
            .get(&key)
            .filter(|cached| cached.mtime == modified && cached.len == size)
            .map(|cached| cached.content.clone());
        hit
    };
    if let Some(text) = fresh {
        tracing::debug!(
            path = %key.display(),
            bytes = text.len(),
            "memory cache hit"
        );
        return Ok((text, true));
    }

    // Miss (or stale entry): read from disk and remember the result.
    let text = std::fs::read_to_string(path)?;
    let refreshed = CachedEntry {
        mtime: modified,
        len: size,
        content: text.clone(),
    };
    cache()
        .lock()
        .expect("memory cache mutex poisoned")
        .insert(key, refreshed);
    Ok((text, false))
}
/// Empties the shared cache if it has been created; does nothing otherwise.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
#[cfg(test)]
pub(crate) fn clear_cache_for_tests() {
    let Some(shared) = MEMORY_CACHE.get() else {
        return;
    };
    let mut entries = shared.lock().expect("memory cache mutex poisoned");
    entries.clear();
}
/// Loads the global memory followed by the project memory for `project_root`
/// and returns them joined by a blank line.
///
/// Either part may be absent; when both are absent the result is an empty
/// string. Loads served from the cache are not logged.
///
/// # Errors
/// Propagates any error from reading the global or project memory file.
pub fn load(project_root: &Path) -> Result<String> {
    let mut sections: Vec<String> = Vec::new();

    if let Some((text, from_cache)) = load_global()? {
        if !from_cache {
            tracing::info!("Loaded global memory ({} bytes)", text.len());
        }
        sections.push(text);
    }

    match load_project(project_root)? {
        Some((filename, text, from_cache)) => {
            if !from_cache {
                tracing::info!(
                    "Loaded project memory from {filename} ({} bytes)",
                    text.len()
                );
            }
            sections.push(text);
        }
        None => tracing::info!("No project memory file found"),
    }

    Ok(sections.join("\n\n"))
}
/// Writes `entry` to the project's memory file under `project_root`.
///
/// The target is the file reported by [`active_project_file`]; when there is
/// none, `KODA_MEMORY_FILE` is used.
///
/// # Errors
/// Propagates any error from reading or writing the target file.
pub fn append(project_root: &Path, entry: &str) -> Result<()> {
    let target_filename = match active_project_file(project_root) {
        Some(existing) => existing,
        None => KODA_MEMORY_FILE.to_string(),
    };
    let destination = project_root.join(&target_filename);
    write_or_replace_section(&destination, entry)?;
    tracing::info!("Wrote to {target_filename}: {entry}");
    Ok(())
}
/// Returns the name of the first entry of `PROJECT_MEMORY_FILES` that exists
/// inside `project_root`, or `None` when none of them exists.
pub fn active_project_file(project_root: &Path) -> Option<String> {
    PROJECT_MEMORY_FILES
        .iter()
        .copied()
        .find(|candidate| project_root.join(candidate).exists())
        .map(str::to_owned)
}
/// Writes `entry` to the global memory file, creating its parent directories
/// first when necessary.
///
/// # Errors
/// Fails when the global memory path cannot be determined, when the parent
/// directory cannot be created, or when the file cannot be read or written.
pub fn append_global(entry: &str) -> Result<()> {
    let Some(path) = global_memory_path() else {
        return Err(anyhow::anyhow!(
            "Cannot determine home directory for global memory"
        ));
    };
    if let Some(dir) = path.parent() {
        std::fs::create_dir_all(dir)?;
    }
    write_or_replace_section(&path, entry)?;
    tracing::info!("Wrote to global memory: {entry}");
    Ok(())
}
fn write_or_replace_section(path: &Path, entry: &str) -> Result<()> {
let heading = extract_heading(entry);
let existing = if path.exists() {
std::fs::read_to_string(path)?
} else {
String::new()
};
let new_content = match heading {
Some(ref h) if section_exists(&existing, h) => replace_section(&existing, h, entry),
_ => {
let mut buf = existing;
if !buf.is_empty() && !buf.ends_with('\n') {
buf.push('\n');
}
buf.push_str(&format!("\n- {entry}"));
buf.push('\n');
buf
}
};
std::fs::write(path, new_content)?;
Ok(())
}
/// Returns the trimmed first line of `entry` when it begins with `"## "`.
///
/// Yields `None` for an empty entry or when the first line is anything other
/// than a level-two heading.
fn extract_heading(entry: &str) -> Option<String> {
    entry
        .lines()
        .next()
        .map(str::trim)
        .filter(|candidate| candidate.starts_with("## "))
        .map(str::to_owned)
}
/// Reports whether any line of `content`, once trimmed, equals `heading`.
fn section_exists(content: &str, heading: &str) -> bool {
    for candidate in content.lines() {
        if candidate.trim() == heading {
            return true;
        }
    }
    false
}
/// Rebuilds `content` with the first section headed by `heading` swapped for
/// `replacement`.
///
/// The section starts at the first line whose trimmed text equals `heading`
/// and runs up to (not including) the next line whose trimmed text starts with
/// `"## "`, or to the end of the input. `replacement` is emitted in its place,
/// followed by a newline if it does not already end with one. Every retained
/// line is written back with a trailing `'\n'`. When `heading` never occurs,
/// the lines are returned unchanged apart from that newline normalization.
fn replace_section(content: &str, heading: &str, replacement: &str) -> String {
    /// Position of the scan relative to the section being replaced.
    enum Cursor {
        BeforeTarget,
        InsideTarget,
        AfterTarget,
    }

    let mut out = String::new();
    let mut cursor = Cursor::BeforeTarget;

    for raw in content.lines() {
        let stripped = raw.trim();
        let keep_line = match cursor {
            Cursor::BeforeTarget if stripped == heading => {
                out.push_str(replacement);
                if !replacement.ends_with('\n') {
                    out.push('\n');
                }
                cursor = Cursor::InsideTarget;
                false
            }
            Cursor::InsideTarget => {
                // Old section body is dropped until the next level-two heading,
                // which itself is kept.
                if stripped.starts_with("## ") {
                    cursor = Cursor::AfterTarget;
                    true
                } else {
                    false
                }
            }
            Cursor::BeforeTarget | Cursor::AfterTarget => true,
        };
        if keep_line {
            out.push_str(raw);
            out.push('\n');
        }
    }
    out
}
/// Reads the global memory file through the cache.
///
/// Returns `Ok(None)` when the path cannot be determined, the file does not
/// exist, or its content is blank. Otherwise returns the content and whether
/// it was served from the cache.
///
/// # Errors
/// Propagates any error from reading the file.
fn load_global() -> Result<Option<(String, bool)>> {
    let Some(location) = global_memory_path() else {
        return Ok(None);
    };
    if !location.exists() {
        return Ok(None);
    }
    let (text, from_cache) = read_through_cache(&location)?;
    if text.trim().is_empty() {
        return Ok(None);
    }
    Ok(Some((text, from_cache)))
}
/// Reads the first non-blank project memory file found in `project_root`.
///
/// Candidates are tried in `PROJECT_MEMORY_FILES` order; files that are
/// missing or contain only whitespace are skipped. On success returns the
/// file name, its content, and whether the content came from the cache.
///
/// # Errors
/// Propagates any error from reading an existing candidate file.
fn load_project(project_root: &Path) -> Result<Option<(String, String, bool)>> {
    for candidate in PROJECT_MEMORY_FILES {
        let location = project_root.join(candidate);
        if !location.exists() {
            continue;
        }
        let (text, from_cache) = read_through_cache(&location)?;
        if text.trim().is_empty() {
            continue;
        }
        return Ok(Some((candidate.to_string(), text, from_cache)));
    }
    Ok(None)
}
/// Builds `<home>/.config/koda/<GLOBAL_MEMORY_FILE>`.
///
/// `<home>` is taken from the `HOME` environment variable, falling back to
/// `USERPROFILE`; returns `None` when neither can be read.
fn global_memory_path() -> Option<PathBuf> {
    let home = match std::env::var("HOME") {
        Ok(value) => value,
        Err(_) => std::env::var("USERPROFILE").ok()?,
    };
    let mut location = PathBuf::from(home);
    location.push(".config");
    location.push("koda");
    location.push(GLOBAL_MEMORY_FILE);
    Some(location)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Serializes the tests that clear or inspect the process-wide cache.
    ///
    /// The test harness runs tests in parallel. Without this lock, one test's
    /// `clear_cache_for_tests()` can wipe an entry between another test's
    /// `load` and its cache inspection, making that test fail and poisoning
    /// the cache mutex for everyone else.
    static CACHE_TEST_GUARD: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Takes the cache-test lock, ignoring poisoning so that one failed cache
    /// test does not cascade into the others.
    fn lock_cache_tests() -> std::sync::MutexGuard<'static, ()> {
        CACHE_TEST_GUARD
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    // --- loading -----------------------------------------------------------

    #[test]
    fn test_load_missing_memory_returns_empty() {
        let tmp = TempDir::new().unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.is_empty());
    }

    #[test]
    fn test_load_memory_md() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("MEMORY.md"), "# Project notes\n- Uses Rust").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("Uses Rust"));
    }

    #[test]
    fn test_load_claude_md_compat() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "# Claude rules\n- Be concise").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("Be concise"));
    }

    #[test]
    fn test_load_agents_md_compat() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "# Agent rules\n- DRY").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("DRY"));
    }

    #[test]
    fn test_memory_md_takes_priority_over_claude_md() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("MEMORY.md"), "koda-memory").unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "claude-rules").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("koda-memory"));
        assert!(!content.contains("claude-rules"));
    }

    #[test]
    fn test_claude_md_takes_priority_over_agents_md() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "claude-rules").unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "puppy-rules").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("claude-rules"));
        assert!(!content.contains("puppy-rules"));
    }

    // --- appending ---------------------------------------------------------

    #[test]
    fn test_append_creates_and_appends() {
        let tmp = TempDir::new().unwrap();
        append(tmp.path(), "first entry").unwrap();
        append(tmp.path(), "second entry").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("first entry"));
        assert!(content.contains("second entry"));
    }

    #[test]
    fn test_append_writes_to_active_file() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "existing claude rules").unwrap();
        append(tmp.path(), "new koda insight").unwrap();
        assert!(!tmp.path().join("MEMORY.md").exists());
        let memory = std::fs::read_to_string(tmp.path().join("CLAUDE.md")).unwrap();
        assert!(memory.contains("new koda insight"));
    }

    #[test]
    fn test_active_project_file() {
        let tmp = TempDir::new().unwrap();
        assert_eq!(active_project_file(tmp.path()), None);
        std::fs::write(tmp.path().join("AGENTS.md"), "rules").unwrap();
        assert_eq!(
            active_project_file(tmp.path()),
            Some("AGENTS.md".to_string())
        );
        std::fs::write(tmp.path().join("MEMORY.md"), "memory").unwrap();
        assert_eq!(
            active_project_file(tmp.path()),
            Some("MEMORY.md".to_string())
        );
    }

    // --- section helpers ---------------------------------------------------

    #[test]
    fn test_extract_heading() {
        assert_eq!(
            extract_heading("## Workflow Preferences\n- item"),
            Some("## Workflow Preferences".to_string())
        );
        assert_eq!(extract_heading("just a plain note"), None);
        assert_eq!(extract_heading("# Top level heading"), None);
        assert_eq!(extract_heading(""), None);
    }

    #[test]
    fn test_section_exists() {
        let content = "# Title\n## Workflow Preferences\n- item1\n## Other\n- item2";
        assert!(section_exists(content, "## Workflow Preferences"));
        assert!(section_exists(content, "## Other"));
        assert!(!section_exists(content, "## Missing"));
    }

    #[test]
    fn test_replace_section() {
        let content = "# Title\n## Workflow Preferences\n- old item1\n- old item2\n## Other Section\n- keep this\n";
        let replacement = "## Workflow Preferences\n- new item1\n- new item2\n- new item3";
        let result = replace_section(content, "## Workflow Preferences", replacement);
        assert!(result.contains("- new item1"), "Should contain new content");
        assert!(result.contains("- new item3"), "Should contain new content");
        assert!(
            !result.contains("- old item1"),
            "Should not contain old content"
        );
        assert!(
            result.contains("## Other Section"),
            "Should preserve other sections"
        );
        assert!(
            result.contains("- keep this"),
            "Should preserve other section content"
        );
    }

    #[test]
    fn test_replace_section_at_end() {
        let content = "## First\n- a\n## Second\n- old\n";
        let replacement = "## Second\n- new";
        let result = replace_section(content, "## Second", replacement);
        assert!(result.contains("## First"), "Should preserve first section");
        assert!(
            result.contains("- a"),
            "Should preserve first section content"
        );
        assert!(result.contains("- new"), "Should contain replacement");
        assert!(!result.contains("- old"), "Should not contain old content");
    }

    #[test]
    fn test_append_merges_existing_section() {
        let tmp = TempDir::new().unwrap();
        let existing = "## Workflow Preferences\n- old item\n";
        std::fs::write(tmp.path().join("MEMORY.md"), existing).unwrap();
        append(
            tmp.path(),
            "## Workflow Preferences\n- updated item\n- new item",
        )
        .unwrap();
        let content = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(
            content.contains("- updated item"),
            "Should contain new content"
        );
        assert!(content.contains("- new item"), "Should contain new content");
        assert!(
            !content.contains("- old item"),
            "Should not contain old content"
        );
        assert_eq!(
            content.matches("## Workflow Preferences").count(),
            1,
            "Should have exactly one copy of the heading"
        );
    }

    #[test]
    fn test_append_new_section_still_appends() {
        let tmp = TempDir::new().unwrap();
        let existing = "## Existing Section\n- item\n";
        std::fs::write(tmp.path().join("MEMORY.md"), existing).unwrap();
        append(tmp.path(), "## New Section\n- new item").unwrap();
        let content = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(content.contains("## Existing Section"));
        assert!(content.contains("## New Section"));
        assert!(content.contains("- new item"));
    }

    #[test]
    fn test_append_plain_entry_still_appends() {
        let tmp = TempDir::new().unwrap();
        append(tmp.path(), "just a plain note").unwrap();
        append(tmp.path(), "another plain note").unwrap();
        let content = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(content.contains("just a plain note"));
        assert!(content.contains("another plain note"));
    }

    // --- cache (serialized: these clear and inspect shared global state) ----

    #[test]
    fn test_cache_hit_skips_disk_reread() {
        let _serial = lock_cache_tests();
        clear_cache_for_tests();
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("CLAUDE.md");
        std::fs::write(&path, "# pinned content\n- entry one").unwrap();
        let first = load(tmp.path()).unwrap();
        assert!(first.contains("entry one"));
        for _ in 0..10 {
            let again = load(tmp.path()).unwrap();
            assert_eq!(again, first, "cache must serve identical bytes");
        }
        let canonical = path.canonicalize().unwrap();
        let map = cache().lock().unwrap();
        let entry = map.get(&canonical).expect("path must be cached");
        assert_eq!(entry.content, "# pinned content\n- entry one");
        assert_eq!(entry.len, std::fs::metadata(&path).unwrap().len());
    }

    #[test]
    fn test_cache_invalidates_on_append() {
        let _serial = lock_cache_tests();
        clear_cache_for_tests();
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("MEMORY.md"), "## Initial\n- one\n").unwrap();
        let before = load(tmp.path()).unwrap();
        assert!(before.contains("- one"));
        assert!(!before.contains("- two"));
        // Wait long enough for the modification time to differ even on
        // filesystems with one-second timestamp granularity.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        append(tmp.path(), "- two").unwrap();
        let after = load(tmp.path()).unwrap();
        assert!(
            after.contains("- two"),
            "post-append load must surface the new entry; got: {after:?}"
        );
    }

    #[test]
    fn test_cache_invalidates_on_size_change() {
        let _serial = lock_cache_tests();
        clear_cache_for_tests();
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("CLAUDE.md");
        std::fs::write(&path, "short").unwrap();
        let before = load(tmp.path()).unwrap();
        assert!(before.contains("short"));
        std::fs::write(&path, "this is much longer content than before").unwrap();
        let after = load(tmp.path()).unwrap();
        assert!(
            after.contains("much longer"),
            "len-delta must trigger a refresh; got: {after:?}"
        );
    }

    #[test]
    fn test_cache_isolates_by_path() {
        let _serial = lock_cache_tests();
        clear_cache_for_tests();
        let proj_a = TempDir::new().unwrap();
        let proj_b = TempDir::new().unwrap();
        std::fs::write(proj_a.path().join("CLAUDE.md"), "alpha-content").unwrap();
        std::fs::write(proj_b.path().join("CLAUDE.md"), "beta-content").unwrap();
        let a = load(proj_a.path()).unwrap();
        let b = load(proj_b.path()).unwrap();
        assert!(a.contains("alpha-content"));
        assert!(!a.contains("beta-content"));
        assert!(b.contains("beta-content"));
        assert!(!b.contains("alpha-content"));
        assert_eq!(load(proj_a.path()).unwrap(), a);
        assert_eq!(load(proj_b.path()).unwrap(), b);
    }

    #[test]
    fn test_concurrent_loads_share_cache() {
        let _serial = lock_cache_tests();
        clear_cache_for_tests();
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("CLAUDE.md");
        std::fs::write(&path, "shared-by-many").unwrap();
        let project_root = tmp.path().to_path_buf();
        let handles: Vec<_> = (0..8)
            .map(|_| {
                let pr = project_root.clone();
                std::thread::spawn(move || {
                    for _ in 0..10 {
                        let c = load(&pr).unwrap();
                        assert!(c.contains("shared-by-many"));
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
        let canonical = path.canonicalize().unwrap();
        let map = cache().lock().unwrap();
        assert!(
            map.contains_key(&canonical),
            "cache must contain entry for {canonical:?}"
        );
    }
}