Skip to main content

koda_core/tools/
memory.rs

//! Memory tools — read and write semantic memory.
//!
//! Exposes `MemoryRead` and `MemoryWrite` as tools the LLM can call
//! to inspect and persist project and global context.
//!
//! ## MemoryRead
//!
//! Returns the current contents of the project and global memory files,
//! preceded by a header describing the active project memory file.
//! Takes no parameters.
//!
//! ## MemoryWrite
//!
//! Appends a fact to `MEMORY.md` (project) or `~/.config/koda/memory.md` (global).
//!
//! - **`content`** (required) — the fact or convention to remember
//! - **`scope`** (optional, default `"project"`) — `"project"` or `"global"`
//!
//! Memory is injected into the system prompt on every turn, so saved facts
//! persist across sessions and compactions. See [`crate::memory`] for the
//! file format and loading logic.
21
22use crate::memory;
23use crate::providers::ToolDefinition;
24use anyhow::Result;
25use serde_json::{Value, json};
26use std::path::Path;
27
28/// Return tool definitions for the LLM.
29pub fn definitions() -> Vec<ToolDefinition> {
30    vec![
31        ToolDefinition {
32            name: "MemoryRead".to_string(),
33            description: "Read project and global memory (MEMORY.md + ~/.config/koda/memory.md)."
34                .to_string(),
35            parameters: json!({
36                "type": "object",
37                "properties": {}
38            }),
39        },
40        ToolDefinition {
41            name: "MemoryWrite".to_string(),
42            description: "Save a project insight or rule to persistent memory (MEMORY.md). \
43                Set scope='global' for user-wide preferences (~/.config/koda/memory.md)."
44                .to_string(),
45            parameters: json!({
46                "type": "object",
47                "properties": {
48                    "content": {
49                        "type": "string",
50                        "description": "The insight or rule to remember"
51                    },
52                    "scope": {
53                        "type": "string",
54                        "description": "'project' (default) or 'global'"
55                    }
56                },
57                "required": ["content"]
58            }),
59        },
60    ]
61}
62
63/// Read all loaded memory.
64pub async fn memory_read(project_root: &Path) -> Result<String> {
65    let content = memory::load(project_root)?;
66    if content.is_empty() {
67        return Ok(
68            "No memory stored yet. Use MemoryWrite to save project context or preferences."
69                .to_string(),
70        );
71    }
72
73    let active = memory::active_project_file(project_root);
74    let header = match active {
75        Some(f) => format!("Active project memory file: {f}"),
76        None => "No project memory file (will create MEMORY.md on first write)".to_string(),
77    };
78
79    Ok(format!("{header}\n\n{content}"))
80}
81
82/// Write a memory entry.
83pub async fn memory_write(project_root: &Path, args: &Value) -> Result<String> {
84    let content = args["content"]
85        .as_str()
86        .ok_or_else(|| anyhow::anyhow!("Missing 'content' argument"))?;
87    let scope = args["scope"].as_str().unwrap_or("project");
88
89    match scope {
90        "global" => {
91            memory::append_global(content)?;
92            Ok(format!("Saved to global memory: {content}"))
93        }
94        _ => {
95            memory::append(project_root, content)?;
96            Ok(format!("Saved to project memory: {content}"))
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // ── definitions ──────────────────────────────────────────────────────

    #[test]
    fn test_definitions_returns_two_tools() {
        let defs = definitions();
        assert_eq!(defs.len(), 2);
        let names: Vec<&str> = defs.iter().map(|d| d.name.as_str()).collect();
        assert!(names.contains(&"MemoryRead"));
        assert!(names.contains(&"MemoryWrite"));
    }

    #[test]
    fn test_memory_write_requires_content() {
        let write_def = definitions()
            .into_iter()
            .find(|d| d.name == "MemoryWrite")
            .unwrap();
        let required: Vec<&str> = write_def.parameters["required"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect();
        assert!(required.contains(&"content"));
        assert!(!required.contains(&"scope"), "scope should be optional");
    }

    #[test]
    fn test_memory_read_has_no_required_params() {
        let read_def = definitions()
            .into_iter()
            .find(|d| d.name == "MemoryRead")
            .unwrap();
        // MemoryRead takes no parameters at all.
        let props = &read_def.parameters["properties"];
        assert!(
            props.as_object().map(|o| o.is_empty()).unwrap_or(true),
            "MemoryRead should have no properties"
        );
    }

    // ── memory_read / memory_write ──────────────────────────────────

    #[tokio::test]
    async fn test_memory_read_empty() {
        let tmp = TempDir::new().unwrap();
        let result = memory_read(tmp.path()).await.unwrap();
        assert!(result.contains("No memory stored"));
    }

    #[tokio::test]
    async fn test_memory_read_with_content() {
        let tmp = TempDir::new().unwrap();
        std::fs::write(tmp.path().join("MEMORY.md"), "# Notes\n- Uses Rust").unwrap();
        let result = memory_read(tmp.path()).await.unwrap();
        assert!(result.contains("Uses Rust"));
        assert!(result.contains("MEMORY.md"));
    }

    #[tokio::test]
    async fn test_memory_write_project() {
        let tmp = TempDir::new().unwrap();
        let args = json!({ "content": "This project uses SQLite" });
        let result = memory_write(tmp.path(), &args).await.unwrap();
        assert!(result.contains("project memory"));

        let content = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(content.contains("This project uses SQLite"));
    }

    #[tokio::test]
    async fn test_memory_write_defaults_to_project() {
        let tmp = TempDir::new().unwrap();
        let args = json!({ "content": "no scope specified" });
        memory_write(tmp.path(), &args).await.unwrap();
        assert!(tmp.path().join("MEMORY.md").exists());
    }

    #[tokio::test]
    async fn test_memory_write_explicit_project_scope() {
        let tmp = TempDir::new().unwrap();
        let args = json!({ "content": "explicit scope", "scope": "project" });
        let result = memory_write(tmp.path(), &args).await.unwrap();
        assert!(result.contains("project memory"));

        let saved = std::fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap();
        assert!(saved.contains("explicit scope"));
    }

    #[tokio::test]
    async fn test_memory_write_missing_content_errors() {
        let tmp = TempDir::new().unwrap();
        let err = memory_write(tmp.path(), &json!({})).await.unwrap_err();
        assert!(err.to_string().contains("content"));
        // A rejected call must not create the project memory file.
        assert!(!tmp.path().join("MEMORY.md").exists());
    }

    #[tokio::test]
    async fn test_memory_write_non_string_content_errors() {
        let tmp = TempDir::new().unwrap();
        let args = json!({ "content": 42 });
        assert!(memory_write(tmp.path(), &args).await.is_err());
    }

    #[tokio::test]
    async fn test_memory_write_then_read_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let args = json!({ "content": "Prefers anyhow for errors" });
        memory_write(tmp.path(), &args).await.unwrap();

        let result = memory_read(tmp.path()).await.unwrap();
        assert!(result.contains("Prefers anyhow for errors"));
    }
}