Skip to main content

koda_core/
memory.rs

1//! Semantic memory: project context injected into the system prompt.
2//!
3//! Memory is stored as human-readable Markdown, loaded from two tiers:
4//!
5//! **Global** (`~/.config/koda/memory.md`):
6//!   User-wide preferences and conventions that apply to all projects.
7//!
8//! **Project-local** (first match wins):
9//!   1. `MEMORY.md`  — Koda native
10//!   2. `CLAUDE.md`  — Claude Code compatibility
11//!   3. `AGENTS.md`  — Code Puppy compatibility
12//!
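//! Both tiers are concatenated (global first, then project-local) and
//! injected into the system prompt. When Koda writes (auto-memory), it
//! updates the active project file (see [`active_project_file`]), creating
//! `MEMORY.md` only if no project memory file exists yet.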
15
16use anyhow::Result;
17use std::path::{Path, PathBuf};
18
/// Koda's native project memory filename.
///
/// It is the highest-priority project file and the one created when a write
/// happens in a project that has no memory file yet.
const KODA_MEMORY_FILE: &str = "MEMORY.md";

/// Project-local memory candidates, highest priority first.
///
/// Koda's own file leads; the others are read for compatibility with memory
/// files written for other coding agents.
const PROJECT_MEMORY_FILES: &[&str] = &[KODA_MEMORY_FILE, "CLAUDE.md", "AGENTS.md"];

/// Name of the user-wide memory file, stored under `~/.config/koda/`.
const GLOBAL_MEMORY_FILE: &str = "memory.md";
27
28/// Load memory from both global and project-local sources.
29///
30/// Returns the combined content (global first, then project-local).
31/// Returns an empty string if no memory files exist.
32pub fn load(project_root: &Path) -> Result<String> {
33    let mut parts: Vec<String> = Vec::new();
34
35    // 1. Global memory (~/.config/koda/memory.md)
36    if let Some(global) = load_global()? {
37        tracing::info!("Loaded global memory ({} bytes)", global.len());
38        parts.push(global);
39    }
40
41    // 2. Project-local memory (first match wins)
42    if let Some((filename, content)) = load_project(project_root)? {
43        tracing::info!(
44            "Loaded project memory from {filename} ({} bytes)",
45            content.len()
46        );
47        parts.push(content);
48    } else {
49        tracing::info!("No project memory file found");
50    }
51
52    Ok(parts.join("\n\n"))
53}
54
55/// Write an entry to the project's memory file.
56///
57/// If the entry starts with a `## Heading`, and a section with that
58/// heading already exists in the file, the existing section is
59/// **replaced** (updated in place). Otherwise the entry is appended.
60///
61/// Always targets the active memory file (or `MEMORY.md` if none exists).
62pub fn append(project_root: &Path, entry: &str) -> Result<()> {
63    let target_filename =
64        active_project_file(project_root).unwrap_or_else(|| KODA_MEMORY_FILE.to_string());
65    let path = project_root.join(&target_filename);
66    write_or_replace_section(&path, entry)?;
67    tracing::info!("Wrote to {target_filename}: {entry}");
68    Ok(())
69}
70
71/// Return which project memory file is active (for display purposes).
72pub fn active_project_file(project_root: &Path) -> Option<String> {
73    for filename in PROJECT_MEMORY_FILES {
74        if project_root.join(filename).exists() {
75            return Some(filename.to_string());
76        }
77    }
78    None
79}
80
81/// Write an entry to the global memory file (~/.config/koda/memory.md).
82///
83/// If the entry starts with a `## Heading` that already exists, the
84/// section is replaced. Otherwise the entry is appended.
85pub fn append_global(entry: &str) -> Result<()> {
86    let path = global_memory_path()
87        .ok_or_else(|| anyhow::anyhow!("Cannot determine home directory for global memory"))?;
88    if let Some(parent) = path.parent() {
89        std::fs::create_dir_all(parent)?;
90    }
91    write_or_replace_section(&path, entry)?;
92    tracing::info!("Wrote to global memory: {entry}");
93    Ok(())
94}
95
96// ── Internal helpers ──────────────────────────────────────────────────────
97
98/// Write an entry to a memory file, merging by `## Heading` if possible.
99///
100/// If `entry` starts with `## <heading>`, we look for an existing section
101/// with the same heading. If found, the old section (heading through to
102/// the next `##` heading or EOF) is replaced with `entry`. If not found
103/// (or `entry` has no heading), the entry is appended.
104fn write_or_replace_section(path: &Path, entry: &str) -> Result<()> {
105    let heading = extract_heading(entry);
106    let existing = if path.exists() {
107        std::fs::read_to_string(path)?
108    } else {
109        String::new()
110    };
111
112    let new_content = match heading {
113        Some(ref h) if section_exists(&existing, h) => replace_section(&existing, h, entry),
114        _ => {
115            // No heading or heading not found → append
116            let mut buf = existing;
117            if !buf.is_empty() && !buf.ends_with('\n') {
118                buf.push('\n');
119            }
120            buf.push_str(&format!("\n- {entry}"));
121            buf.push('\n');
122            buf
123        }
124    };
125
126    std::fs::write(path, new_content)?;
127    Ok(())
128}
129
/// Return the entry's first line, trimmed, when it is a level-2 Markdown
/// heading (`## ...`); `None` for any other (or empty) entry.
fn extract_heading(entry: &str) -> Option<String> {
    entry
        .lines()
        .next()
        .map(str::trim)
        .filter(|line| line.starts_with("## "))
        .map(String::from)
}
139
/// Report whether any line of `content`, ignoring surrounding whitespace,
/// is exactly `heading`.
fn section_exists(content: &str, heading: &str) -> bool {
    content.lines().map(str::trim).any(|line| line == heading)
}
144
/// Replace a `## Heading` section with new content.
///
/// The section starts at the first line equal (after trimming) to `heading`
/// and runs up to, but not including, the next section boundary or EOF.
/// Boundaries are:
///
/// * any `## ` line, and
/// * any level-1 `# ` line that is not inside a fenced code block
///   (```` ``` ```` / `~~~`) of the section — such lines are usually code
///   comments, not headings.
///
/// Deeper headings (`### ...`) are treated as part of the section. Only the
/// first matching section is replaced; every other line is copied through
/// (re-terminated with `\n`). If the old section was separated from the
/// following heading by a blank line, one blank line is kept so the
/// replacement does not run directly into the next section.
fn replace_section(content: &str, heading: &str, replacement: &str) -> String {
    let mut result = String::with_capacity(content.len() + replacement.len());
    let mut in_target_section = false;
    let mut replaced = false;
    // Inside a fenced code block of the section being replaced.
    let mut in_fence = false;
    // Whether the last line dropped from the old section was blank.
    let mut dropped_blank = false;

    for line in content.lines() {
        let trimmed = line.trim();

        if !replaced && trimmed == heading {
            // Emit the replacement in place of the old heading line.
            result.push_str(replacement);
            if !replacement.ends_with('\n') {
                result.push('\n');
            }
            in_target_section = true;
            replaced = true;
            continue;
        }

        if in_target_section {
            if trimmed.starts_with("```") || trimmed.starts_with("~~~") {
                in_fence = !in_fence;
            }
            // A level-1 heading also ends the section; previously it was
            // swallowed (and deleted) together with everything after it.
            let ends_section =
                trimmed.starts_with("## ") || (!in_fence && trimmed.starts_with("# "));
            if !ends_section {
                // Old section body: drop it.
                dropped_blank = trimmed.is_empty();
                continue;
            }
            in_target_section = false;
            if dropped_blank && !result.ends_with("\n\n") {
                result.push('\n');
            }
        }

        result.push_str(line);
        result.push('\n');
    }

    result
}
186
187/// Load global memory from `~/.config/koda/memory.md`.
188fn load_global() -> Result<Option<String>> {
189    let path = global_memory_path();
190    match path {
191        Some(p) if p.exists() => {
192            let content = std::fs::read_to_string(&p)?;
193            if content.trim().is_empty() {
194                Ok(None)
195            } else {
196                Ok(Some(content))
197            }
198        }
199        _ => Ok(None),
200    }
201}
202
203/// Load project-local memory (first matching file wins).
204fn load_project(project_root: &Path) -> Result<Option<(String, String)>> {
205    for filename in PROJECT_MEMORY_FILES {
206        let path = project_root.join(filename);
207        if path.exists() {
208            let content = std::fs::read_to_string(&path)?;
209            if !content.trim().is_empty() {
210                return Ok(Some((filename.to_string(), content)));
211            }
212        }
213    }
214    Ok(None)
215}
216
217/// Path to the global memory file.
218fn global_memory_path() -> Option<PathBuf> {
219    let home = std::env::var("HOME")
220        .or_else(|_| std::env::var("USERPROFILE"))
221        .ok()?;
222    Some(
223        PathBuf::from(home)
224            .join(".config")
225            .join("koda")
226            .join(GLOBAL_MEMORY_FILE),
227    )
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_load_missing_memory_returns_empty() {
        let tmp = TempDir::new().unwrap();
        // `load` always prepends the user's real global memory
        // (~/.config/koda/memory.md), which a test cannot control. Assert the
        // project tier is empty and that `load` adds nothing beyond the
        // global tier, instead of requiring an empty result.
        assert!(load_project(tmp.path()).unwrap().is_none());
        let content = load(tmp.path()).unwrap();
        assert_eq!(content, load_global().unwrap().unwrap_or_default());
    }

    #[test]
    fn test_load_memory_md() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("MEMORY.md"), "# Project notes\n- Uses Rust").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("Uses Rust"));
    }

    #[test]
    fn test_load_claude_md_compat() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "# Claude rules\n- Be concise").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("Be concise"));
    }

    #[test]
    fn test_load_agents_md_compat() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "# Agent rules\n- DRY").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("DRY"));
    }

    #[test]
    fn test_memory_md_takes_priority_over_claude_md() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("MEMORY.md"), "koda-memory").unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "claude-rules").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("koda-memory"));
        assert!(!content.contains("claude-rules"));
    }

    #[test]
    fn test_claude_md_takes_priority_over_agents_md() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("CLAUDE.md"), "claude-rules").unwrap();
        std::fs::write(tmp.path().join("AGENTS.md"), "puppy-rules").unwrap();
        let content = load(tmp.path()).unwrap();
        assert!(content.contains("claude-rules"));
        assert!(!content.contains("puppy-rules"));
    }

    #[test]
    fn test_append_creates_and_appends() {
        let tmp = TempDir::new().unwrap();
        append(tmp.path(), "first entry").unwrap();
        append(tmp.path(), "second entry").unwrap();

        let content = load(tmp.path()).unwrap();
        assert!(content.contains("first entry"));
        assert!(content.contains("second entry"));
    }

    #[test]
    fn test_append_writes_to_active_file() {
        let tmp = TempDir::new().unwrap();
        // If CLAUDE.md exists, append writes directly to CLAUDE.md
        std::fs::write(tmp.path().join("CLAUDE.md"), "existing claude rules").unwrap();
        append(tmp.path(), "new koda insight").unwrap();

        // It should NOT create MEMORY.md
        assert!(!tmp.path().join("MEMORY.md").exists());

        // It SHOULD append to CLAUDE.md
        let memory = std::fs::read_to_string(tmp.path().join("CLAUDE.md")).unwrap();
        assert!(memory.contains("new koda insight"));
    }

    #[test]
    fn test_active_project_file() {
        let tmp = TempDir::new().unwrap();
        assert_eq!(active_project_file(tmp.path()), None);

        std::fs::write(tmp.path().join("AGENTS.md"), "rules").unwrap();
        assert_eq!(
            active_project_file(tmp.path()),
            Some("AGENTS.md".to_string())
        );

        std::fs::write(tmp.path().join("MEMORY.md"), "memory").unwrap();
        assert_eq!(
            active_project_file(tmp.path()),
            Some("MEMORY.md".to_string())
        );
    }

    // ── Section merge tests (#519) ──

    #[test]
    fn test_extract_heading() {
        assert_eq!(
            extract_heading("## Workflow Preferences\n- item"),
            Some("## Workflow Preferences".to_string())
        );
        assert_eq!(extract_heading("just a plain note"), None);
        assert_eq!(extract_heading("# Top level heading"), None); // only ## is matched
        assert_eq!(extract_heading(""), None);
    }

    #[test]
    fn test_section_exists() {
        let content = "# Title\n## Workflow Preferences\n- item1\n## Other\n- item2";
        assert!(section_exists(content, "## Workflow Preferences"));
        assert!(section_exists(content, "## Other"));
        assert!(!section_exists(content, "## Missing"));
    }

    #[test]
    fn test_replace_section() {
        let content = "# Title\n## Workflow Preferences\n- old item1\n- old item2\n## Other Section\n- keep this\n";
        let replacement = "## Workflow Preferences\n- new item1\n- new item2\n- new item3";
        let result = replace_section(content, "## Workflow Preferences", replacement);
        assert!(result.contains("- new item1"), "Should contain new content");
        assert!(result.contains("- new item3"), "Should contain new content");
        assert!(
            !result.contains("- old item1"),
            "Should not contain old content"
        );
        assert!(
            result.contains("## Other Section"),
            "Should preserve other sections"
        );
        assert!(
            result.contains("- keep this"),
            "Should preserve other section content"
        );
    }

    #[test]
    fn test_replace_section_at_end() {
        let content = "## First\n- a\n## Second\n- old\n";
        let replacement = "## Second\n- new";
        let result = replace_section(content, "## Second", replacement);
        assert!(result.contains("## First"), "Should preserve first section");
        assert!(
            result.contains("- a"),
            "Should preserve first section content"
        );
        assert!(result.contains("- new"), "Should contain replacement");
        assert!(!result.contains("- old"), "Should not contain old content");
    }

    #[test]
    fn test_append_merges_existing_section() {
        let tmp = TempDir::new().unwrap();
        let existing = "## Workflow Preferences\n- old item\n";
        std::fs::write(tmp.path().join("MEMORY.md"), existing).unwrap();

        append(
            tmp.path(),
            "## Workflow Preferences\n- updated item\n- new item",
        )
        .unwrap();

        let content = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(
            content.contains("- updated item"),
            "Should contain new content"
        );
        assert!(content.contains("- new item"), "Should contain new content");
        assert!(
            !content.contains("- old item"),
            "Should not contain old content"
        );
        // Should only have one copy of the heading
        assert_eq!(
            content.matches("## Workflow Preferences").count(),
            1,
            "Should have exactly one copy of the heading"
        );
    }

    #[test]
    fn test_append_new_section_still_appends() {
        let tmp = TempDir::new().unwrap();
        let existing = "## Existing Section\n- item\n";
        std::fs::write(tmp.path().join("MEMORY.md"), existing).unwrap();

        append(tmp.path(), "## New Section\n- new item").unwrap();

        let content = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(content.contains("## Existing Section"));
        assert!(content.contains("## New Section"));
        assert!(content.contains("- new item"));
    }

    #[test]
    fn test_append_plain_entry_still_appends() {
        let tmp = TempDir::new().unwrap();
        append(tmp.path(), "just a plain note").unwrap();
        append(tmp.path(), "another plain note").unwrap();

        let content = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(content.contains("just a plain note"));
        assert!(content.contains("another plain note"));
    }
}
440}