agents_toolkit/
filesystem.rs

1use std::collections::BTreeMap;
2use std::sync::{Arc, RwLock};
3
4use agents_core::agent::{ToolHandle, ToolResponse};
5use agents_core::command::{Command, StateDiff};
6use agents_core::messaging::{AgentMessage, MessageContent, MessageRole, ToolInvocation};
7use agents_core::state::AgentStateSnapshot;
8use async_trait::async_trait;
9use serde::Deserialize;
10
11use crate::{metadata_from, tool_text_response};
12
13#[derive(Clone)]
14pub struct LsTool {
15    pub name: String,
16    pub state: Arc<RwLock<AgentStateSnapshot>>,
17}
18
19#[async_trait]
20impl ToolHandle for LsTool {
21    fn name(&self) -> &str {
22        &self.name
23    }
24
25    async fn invoke(&self, invocation: ToolInvocation) -> anyhow::Result<ToolResponse> {
26        let state = self.state.read().expect("filesystem read lock poisoned");
27        let files: Vec<String> = state.files.keys().cloned().collect();
28        Ok(ToolResponse::Message(AgentMessage {
29            role: MessageRole::Tool,
30            content: MessageContent::Json(serde_json::json!(files)),
31            metadata: metadata_from(&invocation),
32        }))
33    }
34}
35
36#[derive(Clone)]
37pub struct ReadFileTool {
38    pub name: String,
39    pub state: Arc<RwLock<AgentStateSnapshot>>,
40}
41
42#[derive(Debug, Deserialize)]
43struct ReadFileArgs {
44    #[serde(rename = "file_path")]
45    path: String,
46    #[serde(default)]
47    offset: usize,
48    #[serde(default = "default_limit")]
49    limit: usize,
50}
51
/// Line limit used when a `read_file` call does not pass `limit`.
const fn default_limit() -> usize {
    2_000
}
55
56#[async_trait]
57impl ToolHandle for ReadFileTool {
58    fn name(&self) -> &str {
59        &self.name
60    }
61
62    async fn invoke(&self, invocation: ToolInvocation) -> anyhow::Result<ToolResponse> {
63        let args: ReadFileArgs = serde_json::from_value(invocation.args.clone())?;
64        let state = self.state.read().expect("filesystem read lock poisoned");
65
66        let Some(contents) = state.files.get(&args.path) else {
67            return Ok(tool_text_response(
68                &invocation,
69                format!("Error: File '{}' not found", args.path),
70            ));
71        };
72
73        if contents.trim().is_empty() {
74            return Ok(tool_text_response(
75                &invocation,
76                "System reminder: File exists but has empty contents",
77            ));
78        }
79
80        let lines: Vec<&str> = contents.lines().collect();
81        if args.offset >= lines.len() {
82            return Ok(tool_text_response(
83                &invocation,
84                format!(
85                    "Error: Line offset {} exceeds file length ({} lines)",
86                    args.offset,
87                    lines.len()
88                ),
89            ));
90        }
91
92        let end = (args.offset + args.limit).min(lines.len());
93        let mut formatted = String::new();
94        for (idx, line) in lines[args.offset..end].iter().enumerate() {
95            let line_number = args.offset + idx + 1;
96            let mut content = line.to_string();
97            if content.len() > 2000 {
98                content.truncate(2000);
99            }
100            formatted.push_str(&format!("{line_number:6}\t{content}\n"));
101        }
102
103        Ok(tool_text_response(
104            &invocation,
105            formatted.trim_end().to_string(),
106        ))
107    }
108}
109
110#[derive(Clone)]
111pub struct WriteFileTool {
112    pub name: String,
113    pub state: Arc<RwLock<AgentStateSnapshot>>,
114}
115
116#[derive(Debug, Deserialize)]
117struct WriteFileArgs {
118    #[serde(rename = "file_path")]
119    path: String,
120    content: String,
121}
122
123#[async_trait]
124impl ToolHandle for WriteFileTool {
125    fn name(&self) -> &str {
126        &self.name
127    }
128
129    async fn invoke(&self, invocation: ToolInvocation) -> anyhow::Result<ToolResponse> {
130        let args: WriteFileArgs = serde_json::from_value(invocation.args.clone())?;
131        let mut state = self.state.write().expect("filesystem write lock poisoned");
132        state.files.insert(args.path.clone(), args.content.clone());
133
134        let mut diff = StateDiff::default();
135        let mut files = BTreeMap::new();
136        files.insert(args.path.clone(), args.content);
137        diff.files = Some(files);
138
139        let command = Command {
140            state: diff,
141            messages: vec![AgentMessage {
142                role: MessageRole::Tool,
143                content: MessageContent::Text(format!("Updated file {}", args.path)),
144                metadata: metadata_from(&invocation),
145            }],
146        };
147
148        Ok(ToolResponse::Command(command))
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use agents_core::command::Command;
    use agents_core::messaging::{MessageContent, MessageRole, ToolInvocation};
    use agents_core::state::AgentStateSnapshot;
    use serde_json::json;

    /// Shared state containing exactly one file, `path`, holding `content`.
    fn shared_state_with_file(path: &str, content: &str) -> Arc<RwLock<AgentStateSnapshot>> {
        let mut snapshot = AgentStateSnapshot::default();
        snapshot.files.insert(path.to_string(), content.to_string());
        Arc::new(RwLock::new(snapshot))
    }

    /// Shared state with no files.
    fn empty_state() -> Arc<RwLock<AgentStateSnapshot>> {
        Arc::new(RwLock::new(AgentStateSnapshot::default()))
    }

    /// Invocation of `tool_name` with the given JSON `args` and call id.
    fn call(tool_name: &str, args: serde_json::Value, call_id: &str) -> ToolInvocation {
        ToolInvocation {
            tool_name: tool_name.into(),
            args,
            tool_call_id: Some(call_id.into()),
        }
    }

    #[tokio::test]
    async fn ls_tool_lists_files() {
        let tool = LsTool {
            name: "ls".to_string(),
            state: shared_state_with_file("notes.txt", "Hello"),
        };

        let response = tool
            .invoke(call("ls", serde_json::Value::Null, "call-1"))
            .await
            .unwrap();

        let ToolResponse::Message(msg) = response else {
            panic!("expected message");
        };
        assert_eq!(msg.metadata.unwrap().tool_call_id.unwrap(), "call-1");
        assert!(matches!(msg.role, MessageRole::Tool));
        match msg.content {
            MessageContent::Json(value) => assert_eq!(value, json!(["notes.txt"])),
            other => panic!("expected json, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn read_file_returns_formatted_content() {
        let tool = ReadFileTool {
            name: "read_file".into(),
            state: shared_state_with_file("main.rs", "fn main() {}\nprintln!(\"hi\");"),
        };
        let args = json!({ "file_path": "main.rs", "offset": 0, "limit": 10 });

        let response = tool.invoke(call("read_file", args, "call-2")).await.unwrap();

        let ToolResponse::Message(msg) = response else {
            panic!("expected message");
        };
        match msg.content {
            MessageContent::Text(text) => assert!(text.contains("fn main")),
            other => panic!("expected text, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn write_file_returns_command_with_update() {
        let tool = WriteFileTool {
            name: "write_file".into(),
            state: empty_state(),
        };
        let args = json!({ "file_path": "notes.txt", "content": "new" });

        let response = tool.invoke(call("write_file", args, "call-3")).await.unwrap();

        let ToolResponse::Command(Command { state: diff, messages }) = response else {
            panic!("expected command");
        };
        assert!(diff.files.unwrap().contains_key("notes.txt"));
        let call_id = messages[0]
            .metadata
            .as_ref()
            .unwrap()
            .tool_call_id
            .as_deref();
        assert_eq!(call_id, Some("call-3"));
    }

    #[tokio::test]
    async fn edit_file_missing_returns_error_message() {
        let tool = EditFileTool {
            name: "edit_file".into(),
            state: empty_state(),
        };
        let args = json!({
            "file_path": "missing.txt",
            "old_string": "foo",
            "new_string": "bar"
        });

        let response = tool.invoke(call("edit_file", args, "call-4")).await.unwrap();

        let ToolResponse::Message(msg) = response else {
            panic!("expected message");
        };
        match msg.content {
            MessageContent::Text(text) => assert!(text.contains("missing.txt")),
            other => panic!("expected text, got {other:?}"),
        }
    }
}
276
277#[derive(Clone)]
278pub struct EditFileTool {
279    pub name: String,
280    pub state: Arc<RwLock<AgentStateSnapshot>>,
281}
282
283#[derive(Debug, Deserialize)]
284struct EditFileArgs {
285    #[serde(rename = "file_path")]
286    path: String,
287    #[serde(rename = "old_string")]
288    old: String,
289    #[serde(rename = "new_string")]
290    new: String,
291    #[serde(default)]
292    replace_all: bool,
293}
294
// Complementary behaviour tests for the filesystem tools. This module was a
// verbatim copy of `tests` (every case ran twice and added no coverage). It now
// exercises the edit/read/write paths that `tests` does not reach.
#[cfg(test)]
mod tests_old_location {
    use super::*;
    use agents_core::command::Command;
    use agents_core::messaging::{MessageContent, ToolInvocation};
    use agents_core::state::AgentStateSnapshot;
    use serde_json::json;

    /// Shared state containing exactly one file, `path`, holding `content`.
    fn shared_state_with_file(path: &str, content: &str) -> Arc<RwLock<AgentStateSnapshot>> {
        let mut snapshot = AgentStateSnapshot::default();
        snapshot.files.insert(path.to_string(), content.to_string());
        Arc::new(RwLock::new(snapshot))
    }

    /// Invocation of `tool_name` with the given JSON `args`.
    fn invocation(tool_name: &str, args: serde_json::Value) -> ToolInvocation {
        ToolInvocation {
            tool_name: tool_name.into(),
            args,
            tool_call_id: Some("call-extra".into()),
        }
    }

    /// Text of a `ToolResponse::Message`; panics on any other response shape.
    fn expect_text(response: ToolResponse) -> String {
        match response {
            ToolResponse::Message(msg) => match msg.content {
                MessageContent::Text(text) => text,
                other => panic!("expected text, got {other:?}"),
            },
            _ => panic!("expected message"),
        }
    }

    /// Command inside a `ToolResponse::Command`; panics on any other response shape.
    fn expect_command(response: ToolResponse) -> Command {
        match response {
            ToolResponse::Command(command) => command,
            _ => panic!("expected command"),
        }
    }

    /// Text of the first message carried by `command`.
    fn first_message_text(command: &Command) -> &str {
        match &command.messages[0].content {
            MessageContent::Text(text) => text.as_str(),
            other => panic!("expected text, got {other:?}"),
        }
    }

    /// Current contents of `path` in the shared state, if present.
    fn stored(state: &Arc<RwLock<AgentStateSnapshot>>, path: &str) -> Option<String> {
        state.read().unwrap().files.get(path).cloned()
    }

    fn edit_tool(state: &Arc<RwLock<AgentStateSnapshot>>) -> EditFileTool {
        EditFileTool {
            name: "edit_file".into(),
            state: state.clone(),
        }
    }

    fn read_tool(state: &Arc<RwLock<AgentStateSnapshot>>) -> ReadFileTool {
        ReadFileTool {
            name: "read_file".into(),
            state: state.clone(),
        }
    }

    #[tokio::test]
    async fn edit_file_replaces_single_occurrence() {
        let state = shared_state_with_file("a.txt", "hello world");
        let args = json!({ "file_path": "a.txt", "old_string": "world", "new_string": "rust" });

        let command = expect_command(
            edit_tool(&state)
                .invoke(invocation("edit_file", args))
                .await
                .unwrap(),
        );

        assert_eq!(
            first_message_text(&command),
            "Successfully replaced string in 'a.txt'"
        );
        let diff = command.state.files.expect("diff should carry the edited file");
        assert_eq!(diff.get("a.txt").map(String::as_str), Some("hello rust"));
        assert_eq!(stored(&state, "a.txt").as_deref(), Some("hello rust"));
    }

    #[tokio::test]
    async fn edit_file_rejects_ambiguous_match_without_replace_all() {
        let state = shared_state_with_file("a.txt", "x x");
        let args = json!({ "file_path": "a.txt", "old_string": "x", "new_string": "y" });

        let text = expect_text(
            edit_tool(&state)
                .invoke(invocation("edit_file", args))
                .await
                .unwrap(),
        );

        assert!(text.contains("appears 2 times"), "unexpected reply: {text}");
        assert_eq!(stored(&state, "a.txt").as_deref(), Some("x x"));
    }

    #[tokio::test]
    async fn edit_file_replace_all_reports_count() {
        let state = shared_state_with_file("a.txt", "x x x");
        let args = json!({
            "file_path": "a.txt",
            "old_string": "x",
            "new_string": "y",
            "replace_all": true
        });

        let command = expect_command(
            edit_tool(&state)
                .invoke(invocation("edit_file", args))
                .await
                .unwrap(),
        );

        assert_eq!(
            first_message_text(&command),
            "Successfully replaced 3 instance(s) of the string in 'a.txt'"
        );
        assert_eq!(stored(&state, "a.txt").as_deref(), Some("y y y"));
    }

    #[tokio::test]
    async fn edit_file_reports_missing_string() {
        let state = shared_state_with_file("a.txt", "hello");
        let args = json!({ "file_path": "a.txt", "old_string": "bye", "new_string": "hi" });

        let text = expect_text(
            edit_tool(&state)
                .invoke(invocation("edit_file", args))
                .await
                .unwrap(),
        );

        assert_eq!(text, "Error: String not found in file: 'bye'");
        assert_eq!(stored(&state, "a.txt").as_deref(), Some("hello"));
    }

    #[tokio::test]
    async fn read_file_numbers_lines_from_offset() {
        let state = shared_state_with_file("a.txt", "a\nb\nc");
        let args = json!({ "file_path": "a.txt", "offset": 1, "limit": 1 });

        let text = expect_text(
            read_tool(&state)
                .invoke(invocation("read_file", args))
                .await
                .unwrap(),
        );

        assert_eq!(text, "     2\tb");
    }

    #[tokio::test]
    async fn read_file_rejects_offset_past_end() {
        let state = shared_state_with_file("a.txt", "one\ntwo");
        let args = json!({ "file_path": "a.txt", "offset": 5 });

        let text = expect_text(
            read_tool(&state)
                .invoke(invocation("read_file", args))
                .await
                .unwrap(),
        );

        assert_eq!(text, "Error: Line offset 5 exceeds file length (2 lines)");
    }

    #[tokio::test]
    async fn read_file_reports_whitespace_only_file_as_empty() {
        let state = shared_state_with_file("blank.txt", "  \n\t");
        let args = json!({ "file_path": "blank.txt" });

        let text = expect_text(
            read_tool(&state)
                .invoke(invocation("read_file", args))
                .await
                .unwrap(),
        );

        assert_eq!(text, "System reminder: File exists but has empty contents");
    }

    #[tokio::test]
    async fn write_file_overwrites_shared_state() {
        let state = shared_state_with_file("notes.txt", "old");
        let tool = WriteFileTool {
            name: "write_file".into(),
            state: state.clone(),
        };
        let args = json!({ "file_path": "notes.txt", "content": "new" });

        let command = expect_command(tool.invoke(invocation("write_file", args)).await.unwrap());

        assert_eq!(first_message_text(&command), "Updated file notes.txt");
        assert_eq!(stored(&state, "notes.txt").as_deref(), Some("new"));
    }
}
419
420#[async_trait]
421impl ToolHandle for EditFileTool {
422    fn name(&self) -> &str {
423        &self.name
424    }
425
426    async fn invoke(&self, invocation: ToolInvocation) -> anyhow::Result<ToolResponse> {
427        let args: EditFileArgs = serde_json::from_value(invocation.args.clone())?;
428        let mut state = self.state.write().expect("filesystem write lock poisoned");
429
430        let Some(existing) = state.files.get(&args.path).cloned() else {
431            return Ok(tool_text_response(
432                &invocation,
433                format!("Error: File '{}' not found", args.path),
434            ));
435        };
436
437        if !existing.contains(&args.old) {
438            return Ok(tool_text_response(
439                &invocation,
440                format!("Error: String not found in file: '{}'", args.old),
441            ));
442        }
443
444        if !args.replace_all {
445            let occurrences = existing.matches(&args.old).count();
446            if occurrences > 1 {
447                return Ok(tool_text_response(
448                    &invocation,
449                    format!(
450                        "Error: String '{}' appears {} times in file. Use replace_all=true to replace all instances, or provide a more specific string with surrounding context.",
451                        args.old, occurrences
452                    ),
453                ));
454            }
455        }
456
457        let updated = if args.replace_all {
458            existing.replace(&args.old, &args.new)
459        } else {
460            existing.replacen(&args.old, &args.new, 1)
461        };
462
463        let replacement_count = if args.replace_all {
464            existing.matches(&args.old).count()
465        } else {
466            1
467        };
468
469        state.files.insert(args.path.clone(), updated.clone());
470
471        let mut diff = StateDiff::default();
472        let mut files = BTreeMap::new();
473        files.insert(args.path.clone(), updated);
474        diff.files = Some(files);
475
476        let message = if args.replace_all {
477            format!(
478                "Successfully replaced {} instance(s) of the string in '{}'",
479                replacement_count, args.path
480            )
481        } else {
482            format!("Successfully replaced string in '{}'", args.path)
483        };
484
485        let command = Command {
486            state: diff,
487            messages: vec![AgentMessage {
488                role: MessageRole::Tool,
489                content: MessageContent::Text(message),
490                metadata: metadata_from(&invocation),
491            }],
492        };
493
494        Ok(ToolResponse::Command(command))
495    }
496}