Skip to main content

imp_core/tools/
memory.rs

1use async_trait::async_trait;
2use serde_json::json;
3
4use super::{Tool, ToolContext, ToolOutput};
5use crate::error::Result;
6use crate::memory::MemoryStore;
7use crate::storage;
8use crate::trust::RiskLabel;
9
/// Character budget passed to `MemoryStore::load` for the agent-notes store
/// (the `"memory"` target).
const DEFAULT_MEMORY_LIMIT: usize = 2_200;
/// Character budget passed to `MemoryStore::load` for the user-profile store
/// (the `"user"` target).
const DEFAULT_USER_LIMIT: usize = 1_400;

/// Tool that adds, replaces, or removes entries in the persistent `memory`
/// and `user` stores.
pub struct MemoryTool;
14
15#[async_trait]
16impl Tool for MemoryTool {
17    fn name(&self) -> &str {
18        "memory"
19    }
20
21    fn label(&self) -> &str {
22        "Memory"
23    }
24
25    fn description(&self) -> &str {
26        "Manage persistent memory across sessions. Use to save environment facts, \
27         user preferences, and lessons learned. Target 'memory' for agent notes, \
28         'user' for user profile."
29    }
30
31    fn parameters(&self) -> serde_json::Value {
32        json!({
33            "type": "object",
34            "required": ["action", "target"],
35            "properties": {
36                "action": {
37                    "type": "string",
38                    "enum": ["add", "replace", "remove"],
39                    "description": "Action to perform"
40                },
41                "target": {
42                    "type": "string",
43                    "enum": ["memory", "user"],
44                    "description": "Which store: 'memory' for agent notes, 'user' for user profile"
45                },
46                "content": {
47                    "type": "string",
48                    "description": "Content to add or replacement text"
49                },
50                "old_text": {
51                    "type": "string",
52                    "description": "Unique substring identifying the entry to replace or remove"
53                }
54            }
55        })
56    }
57
58    fn is_readonly(&self) -> bool {
59        false
60    }
61
62    async fn execute(
63        &self,
64        _call_id: &str,
65        params: serde_json::Value,
66        ctx: ToolContext,
67    ) -> Result<ToolOutput> {
68        let action = params["action"].as_str().unwrap_or("");
69        let target = params["target"].as_str().unwrap_or("");
70
71        if action.is_empty() {
72            return Ok(ToolOutput::error("Missing required parameter: action"));
73        }
74        if target.is_empty() {
75            return Ok(ToolOutput::error("Missing required parameter: target"));
76        }
77
78        let (path, char_limit) = match target {
79            "memory" => (storage::global_memory_path(), DEFAULT_MEMORY_LIMIT),
80            "user" => (storage::global_user_path(), DEFAULT_USER_LIMIT),
81            other => {
82                return Ok(ToolOutput::error(format!(
83                    "Unknown target \"{other}\". Use \"memory\" or \"user\"."
84                )));
85            }
86        };
87
88        let mut store = match MemoryStore::load(&path, char_limit) {
89            Ok(s) => s,
90            Err(e) => return Ok(ToolOutput::error(format!("Failed to load memory: {e}"))),
91        };
92
93        let result = match action {
94            "add" => {
95                let content = params["content"].as_str().unwrap_or("");
96                if content.is_empty() {
97                    return Ok(ToolOutput::error(
98                        "Missing required parameter: content (for 'add' action)",
99                    ));
100                }
101                if ctx
102                    .supporting_provenance
103                    .iter()
104                    .any(|provenance| provenance.is_low_trust())
105                {
106                    return Ok(ToolOutput::error(
107                        "Low-trust context cannot be written to durable memory without explicit user adoption.",
108                    ));
109                }
110                if ctx.supporting_provenance.iter().any(|provenance| {
111                    provenance.risk.iter().any(|risk| {
112                        matches!(
113                            risk,
114                            RiskLabel::PossiblePromptInjection | RiskLabel::SecretAdjacent
115                        )
116                    })
117                }) {
118                    return Ok(ToolOutput::error(
119                        "Risk-labeled context cannot be written to durable memory without review.",
120                    ));
121                }
122                store.add(content)?
123            }
124            "replace" => {
125                let old_text = params["old_text"].as_str().unwrap_or("");
126                let content = params["content"].as_str().unwrap_or("");
127                if old_text.is_empty() {
128                    return Ok(ToolOutput::error(
129                        "Missing required parameter: old_text (for 'replace' action)",
130                    ));
131                }
132                if content.is_empty() {
133                    return Ok(ToolOutput::error(
134                        "Missing required parameter: content (for 'replace' action)",
135                    ));
136                }
137                if ctx
138                    .supporting_provenance
139                    .iter()
140                    .any(|provenance| provenance.is_low_trust())
141                {
142                    return Ok(ToolOutput::error(
143                        "Low-trust context cannot replace durable memory without explicit user adoption.",
144                    ));
145                }
146                store.replace(old_text, content)?
147            }
148            "remove" => {
149                let old_text = params["old_text"].as_str().unwrap_or("");
150                if old_text.is_empty() {
151                    return Ok(ToolOutput::error(
152                        "Missing required parameter: old_text (for 'remove' action)",
153                    ));
154                }
155                store.remove(old_text)?
156            }
157            other => {
158                return Ok(ToolOutput::error(format!(
159                    "Unknown action \"{other}\". Use \"add\", \"replace\", or \"remove\"."
160                )));
161            }
162        };
163
164        let json_text = serde_json::to_string_pretty(&result.to_json()).unwrap_or_default();
165        if result.success {
166            Ok(ToolOutput::text(json_text))
167        } else {
168            Ok(ToolOutput::error(json_text))
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ToolContext;
    use std::sync::Arc;

    /// Builds a `ToolContext` with a temp-dir cwd, inert channels and no
    /// supporting provenance. Tests add provenance as needed.
    fn test_ctx() -> ToolContext {
        let (tx, _rx) = tokio::sync::mpsc::channel(16);
        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::channel(16);
        let dir = std::env::temp_dir();
        ToolContext {
            cwd: dir,
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            update_tx: tx,
            command_tx: cmd_tx,
            ui: Arc::new(crate::ui::NullInterface),
            file_cache: Arc::new(crate::tools::FileCache::new()),
            checkpoint_state: Arc::new(crate::tools::CheckpointState::new()),
            file_tracker: Arc::new(std::sync::Mutex::new(crate::tools::FileTracker::new())),
            anchor_store: Arc::new(crate::tools::AnchorStore::new()),
            lua_tool_loader: None,
            mode: crate::config::AgentMode::Full,
            read_max_lines: 500,
            turn_mana_review: Arc::new(std::sync::Mutex::new(
                crate::mana_review::TurnManaReviewAccumulator::default(),
            )),
            config: Arc::new(crate::config::Config::default()),
            run_policy: Default::default(),
            supporting_provenance: Vec::new(),
        }
    }

    /// Context whose only supporting provenance is an external web page
    /// (expected to count as low-trust).
    fn low_trust_ctx() -> ToolContext {
        let mut ctx = test_ctx();
        ctx.supporting_provenance
            .push(crate::trust::Provenance::external_web(
                "https://example.com",
            ));
        ctx
    }

    /// True when `output` is an error whose text blocks contain `needle`.
    fn is_error_mentioning(output: &ToolOutput, needle: &str) -> bool {
        output.is_error
            && output.content.iter().any(|block| {
                matches!(block, imp_llm::ContentBlock::Text { text } if text.contains(needle))
            })
    }

    #[tokio::test]
    async fn memory_tool_validates_params() {
        let tool = MemoryTool;

        // Missing action
        let r = tool
            .execute("c1", json!({"target": "memory"}), test_ctx())
            .await
            .unwrap();
        assert!(r.is_error);

        // Missing target
        let r = tool
            .execute("c2", json!({"action": "add"}), test_ctx())
            .await
            .unwrap();
        assert!(r.is_error);

        // Missing content for add
        let r = tool
            .execute(
                "c3",
                json!({"action": "add", "target": "memory"}),
                test_ctx(),
            )
            .await
            .unwrap();
        assert!(r.is_error);

        // Missing old_text for replace
        let r = tool
            .execute(
                "c5",
                json!({"action": "replace", "target": "memory", "content": "x"}),
                test_ctx(),
            )
            .await
            .unwrap();
        assert!(r.is_error);

        // Missing old_text for remove
        let r = tool
            .execute(
                "c6",
                json!({"action": "remove", "target": "memory"}),
                test_ctx(),
            )
            .await
            .unwrap();
        assert!(r.is_error);

        // Unknown target
        let r = tool
            .execute(
                "c7",
                json!({"action": "add", "target": "bogus", "content": "x"}),
                test_ctx(),
            )
            .await
            .unwrap();
        assert!(is_error_mentioning(&r, "Unknown target"));

        // Unknown action
        let r = tool
            .execute(
                "c8",
                json!({"action": "frobnicate", "target": "memory"}),
                test_ctx(),
            )
            .await
            .unwrap();
        assert!(r.is_error);
    }

    #[tokio::test]
    async fn memory_tool_blocks_low_trust_durable_writes() {
        let tool = MemoryTool;

        let result = tool
            .execute(
                "c4",
                json!({"action": "add", "target": "memory", "content": "remember this"}),
                low_trust_ctx(),
            )
            .await
            .unwrap();

        assert!(is_error_mentioning(&result, "Low-trust context"));
    }

    #[tokio::test]
    async fn memory_tool_blocks_low_trust_replace() {
        let tool = MemoryTool;

        let result = tool
            .execute(
                "c9",
                json!({
                    "action": "replace",
                    "target": "memory",
                    "old_text": "old entry",
                    "content": "injected entry"
                }),
                low_trust_ctx(),
            )
            .await
            .unwrap();

        assert!(is_error_mentioning(&result, "Low-trust context"));
    }
}