Skip to main content

agent_core/controller/tools/
write_file.rs

1//! WriteFile tool implementation
2//!
3//! This tool allows the LLM to write files to the local filesystem.
4//! It integrates with the PermissionRegistry to require user approval
5//! before performing write operations.
6
7use std::collections::HashMap;
8use std::future::Future;
9use std::path::Path;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use tokio::fs;
14
15use super::ask_for_permissions::{PermissionCategory, PermissionRequest};
16use super::permission_registry::PermissionRegistry;
17use super::types::{
18    DisplayConfig, DisplayResult, Executable, ResultContentType, ToolContext, ToolType,
19};
20
/// Name under which the WriteFile tool is registered.
///
/// Returned verbatim by `WriteFileTool`'s `name()` implementation.
pub const WRITE_FILE_TOOL_NAME: &str = "write_file";
23
/// Human-readable description of the WriteFile tool.
///
/// Returned verbatim by `WriteFileTool`'s `description()` implementation. The text
/// lists the usage rules (absolute path, overwrite semantics, parent directory
/// creation, permission requirement) and the possible outcomes of a call.
pub const WRITE_FILE_TOOL_DESCRIPTION: &str = r#"Writes content to a file, creating it if it doesn't exist or overwriting if it does.

Usage:
- The file_path parameter must be an absolute path, not a relative path
- This tool will overwrite the existing file if there is one at the provided path
- Parent directories will be created automatically if they don't exist
- Requires user permission before writing (may be cached for session)

Returns:
- Success message with bytes written on successful write
- Error message if permission is denied or the operation fails"#;
36
/// JSON schema, stored as a string, describing the WriteFile tool's input object.
///
/// Returned verbatim by `WriteFileTool`'s `input_schema()` implementation.
/// `file_path` and `content` are required strings; `create_directories` is an
/// optional boolean.
pub const WRITE_FILE_TOOL_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "file_path": {
            "type": "string",
            "description": "The absolute path to the file to write"
        },
        "content": {
            "type": "string",
            "description": "The content to write to the file"
        },
        "create_directories": {
            "type": "boolean",
            "description": "Whether to create parent directories if they don't exist. Defaults to true."
        }
    },
    "required": ["file_path", "content"]
}"#;
56
/// Tool that writes files to the filesystem with permission checks.
///
/// The tool holds a shared handle to the permission registry; its `execute`
/// implementation consults that registry before any write is performed.
pub struct WriteFileTool {
    /// Shared permission registry, used both to check whether write permission
    /// has already been granted and to register new permission requests.
    permission_registry: Arc<PermissionRegistry>,
}
62
63impl WriteFileTool {
64    /// Create a new WriteFileTool with the given permission registry.
65    ///
66    /// # Arguments
67    /// * `permission_registry` - The registry used to request and cache permissions.
68    pub fn new(permission_registry: Arc<PermissionRegistry>) -> Self {
69        Self { permission_registry }
70    }
71
72    /// Builds a permission request for writing to a file.
73    fn build_permission_request(
74        file_path: &str,
75        content_len: usize,
76        is_overwrite: bool,
77    ) -> PermissionRequest {
78        let action_verb = if is_overwrite { "Overwrite" } else { "Create" };
79        let filename = Path::new(file_path)
80            .file_name()
81            .and_then(|n| n.to_str())
82            .unwrap_or(file_path);
83
84        PermissionRequest {
85            action: format!("{} file: {}", action_verb, filename),
86            reason: Some(format!(
87                "{} file with {} bytes of content",
88                action_verb.to_lowercase(),
89                content_len
90            )),
91            resources: vec![file_path.to_string()],
92            category: PermissionCategory::FileWrite,
93        }
94    }
95}
96
97impl Executable for WriteFileTool {
98    fn name(&self) -> &str {
99        WRITE_FILE_TOOL_NAME
100    }
101
102    fn description(&self) -> &str {
103        WRITE_FILE_TOOL_DESCRIPTION
104    }
105
106    fn input_schema(&self) -> &str {
107        WRITE_FILE_TOOL_SCHEMA
108    }
109
110    fn tool_type(&self) -> ToolType {
111        ToolType::TextEdit
112    }
113
114    fn execute(
115        &self,
116        context: ToolContext,
117        input: HashMap<String, serde_json::Value>,
118    ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
119        let permission_registry = self.permission_registry.clone();
120
121        Box::pin(async move {
122            // ─────────────────────────────────────────────────────────────
123            // Step 1: Extract and validate parameters
124            // ─────────────────────────────────────────────────────────────
125            let file_path = input
126                .get("file_path")
127                .and_then(|v| v.as_str())
128                .ok_or_else(|| "Missing required 'file_path' parameter".to_string())?;
129
130            let content = input
131                .get("content")
132                .and_then(|v| v.as_str())
133                .ok_or_else(|| "Missing required 'content' parameter".to_string())?;
134
135            let create_directories = input
136                .get("create_directories")
137                .and_then(|v| v.as_bool())
138                .unwrap_or(true);
139
140            let path = Path::new(file_path);
141
142            // Validate absolute path
143            if !path.is_absolute() {
144                return Err(format!(
145                    "file_path must be an absolute path, got: {}",
146                    file_path
147                ));
148            }
149
150            // Check if this is an overwrite (file exists) or create (new file)
151            let is_overwrite = path.exists();
152
153            // ─────────────────────────────────────────────────────────────
154            // Step 2: Build permission request
155            // ─────────────────────────────────────────────────────────────
156            let permission_request =
157                Self::build_permission_request(file_path, content.len(), is_overwrite);
158
159            // ─────────────────────────────────────────────────────────────
160            // Step 3: Check if permission is already granted for this session
161            // ─────────────────────────────────────────────────────────────
162            let already_granted = permission_registry
163                .is_granted(context.session_id, &permission_request)
164                .await;
165
166            if !already_granted {
167                // ─────────────────────────────────────────────────────────
168                // Step 4: Request permission from user
169                // This emits ControllerEvent::PermissionRequired to UI
170                // ─────────────────────────────────────────────────────────
171                let response_rx = permission_registry
172                    .register(
173                        context.tool_use_id.clone(),
174                        context.session_id,
175                        permission_request,
176                        context.turn_id.clone(),
177                    )
178                    .await
179                    .map_err(|e| format!("Failed to request permission: {}", e))?;
180
181                // ─────────────────────────────────────────────────────────
182                // Step 5: Block until user responds
183                // ─────────────────────────────────────────────────────────
184                let response = response_rx
185                    .await
186                    .map_err(|_| "Permission request was cancelled".to_string())?;
187
188                // ─────────────────────────────────────────────────────────
189                // Step 6: Check if permission was granted
190                // ─────────────────────────────────────────────────────────
191                if !response.granted {
192                    let reason = response
193                        .message
194                        .unwrap_or_else(|| "Permission denied by user".to_string());
195                    return Err(format!(
196                        "Permission denied to write '{}': {}",
197                        file_path, reason
198                    ));
199                }
200            }
201
202            // ─────────────────────────────────────────────────────────────
203            // Step 7: Create parent directories if requested
204            // ─────────────────────────────────────────────────────────────
205            if create_directories {
206                if let Some(parent) = path.parent() {
207                    if !parent.exists() {
208                        fs::create_dir_all(parent).await.map_err(|e| {
209                            format!("Failed to create parent directories: {}", e)
210                        })?;
211                    }
212                }
213            }
214
215            // ─────────────────────────────────────────────────────────────
216            // Step 8: Perform the write operation
217            // ─────────────────────────────────────────────────────────────
218            let bytes_written = content.len();
219            fs::write(path, content).await.map_err(|e| {
220                format!("Failed to write file '{}': {}", file_path, e)
221            })?;
222
223            let action = if is_overwrite { "overwrote" } else { "created" };
224            Ok(format!(
225                "Successfully {} '{}' ({} bytes)",
226                action, file_path, bytes_written
227            ))
228        })
229    }
230
231    fn display_config(&self) -> DisplayConfig {
232        DisplayConfig {
233            display_name: "Write File".to_string(),
234            display_title: Box::new(|input| {
235                input
236                    .get("file_path")
237                    .and_then(|v| v.as_str())
238                    .map(|p| {
239                        Path::new(p)
240                            .file_name()
241                            .and_then(|n| n.to_str())
242                            .unwrap_or(p)
243                            .to_string()
244                    })
245                    .unwrap_or_default()
246            }),
247            display_content: Box::new(|input, result| {
248                let content_preview = input
249                    .get("content")
250                    .and_then(|v| v.as_str())
251                    .map(|c| {
252                        let lines: Vec<&str> = c.lines().take(10).collect();
253                        if c.lines().count() > 10 {
254                            format!("{}...\n[truncated]", lines.join("\n"))
255                        } else {
256                            lines.join("\n")
257                        }
258                    })
259                    .unwrap_or_else(|| result.to_string());
260
261                DisplayResult {
262                    content: content_preview,
263                    content_type: ResultContentType::PlainText,
264                    is_truncated: input
265                        .get("content")
266                        .and_then(|v| v.as_str())
267                        .map(|c| c.lines().count() > 10)
268                        .unwrap_or(false),
269                    full_length: input
270                        .get("content")
271                        .and_then(|v| v.as_str())
272                        .map(|c| c.lines().count())
273                        .unwrap_or(0),
274                }
275            }),
276        }
277    }
278
279    fn compact_summary(
280        &self,
281        input: &HashMap<String, serde_json::Value>,
282        _result: &str,
283    ) -> String {
284        let filename = input
285            .get("file_path")
286            .and_then(|v| v.as_str())
287            .map(|p| {
288                Path::new(p)
289                    .file_name()
290                    .and_then(|n| n.to_str())
291                    .unwrap_or(p)
292            })
293            .unwrap_or("unknown");
294
295        let bytes = input
296            .get("content")
297            .and_then(|v| v.as_str())
298            .map(|c| c.len())
299            .unwrap_or(0);
300
301        format!("[WriteFile: {} ({} bytes)]", filename, bytes)
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_for_permissions::{PermissionResponse, PermissionScope};
    use crate::controller::types::ControllerEvent;
    use tempfile::TempDir;
    use tokio::sync::mpsc;

    /// Builds a permission registry wired to a fresh event channel and returns
    /// it together with the channel's receiving end.
    fn create_test_registry() -> (Arc<PermissionRegistry>, mpsc::Receiver<ControllerEvent>) {
        let (event_tx, event_rx) = mpsc::channel(16);
        (Arc::new(PermissionRegistry::new(event_tx)), event_rx)
    }

    /// Builds a tool input map in which every listed field is a JSON string.
    fn tool_input(fields: &[(&str, &str)]) -> HashMap<String, serde_json::Value> {
        fields
            .iter()
            .map(|&(key, value)| {
                (
                    key.to_string(),
                    serde_json::Value::String(value.to_string()),
                )
            })
            .collect()
    }

    /// Builds a session-1 tool context with the given tool-use id and no turn id.
    fn tool_context(tool_use_id: &str) -> ToolContext {
        ToolContext {
            session_id: 1,
            tool_use_id: tool_use_id.to_string(),
            turn_id: None,
        }
    }

    /// Spawns a task that waits for the next event and, if it is a
    /// `PermissionRequired` event, answers it with `response`.
    fn answer_next_permission_request(
        registry: &Arc<PermissionRegistry>,
        mut events: mpsc::Receiver<ControllerEvent>,
        response: PermissionResponse,
    ) {
        let registry = registry.clone();
        tokio::spawn(async move {
            if let Some(ControllerEvent::PermissionRequired { tool_use_id, .. }) =
                events.recv().await
            {
                registry.respond(&tool_use_id, response).await.unwrap();
            }
        });
    }

    #[tokio::test]
    async fn test_write_new_file_with_permission_granted() {
        let (registry, event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        let input = tool_input(&[
            ("file_path", file_path.to_str().unwrap()),
            ("content", "Hello, World!"),
        ]);

        // The user grants the request once it arrives.
        answer_next_permission_request(
            &registry,
            event_rx,
            PermissionResponse::grant(PermissionScope::Once),
        );

        let result = tool.execute(tool_context("test-123"), input).await;

        assert!(result.is_ok());
        assert!(file_path.exists());
        assert_eq!(
            tokio::fs::read_to_string(&file_path).await.unwrap(),
            "Hello, World!"
        );
    }

    #[tokio::test]
    async fn test_write_file_permission_denied() {
        let (registry, event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        let input = tool_input(&[
            ("file_path", file_path.to_str().unwrap()),
            ("content", "Hello, World!"),
        ]);

        // The user denies the request once it arrives.
        answer_next_permission_request(
            &registry,
            event_rx,
            PermissionResponse::deny(Some("Not allowed".to_string())),
        );

        let result = tool.execute(tool_context("test-456"), input).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Permission denied"));
        assert!(!file_path.exists());
    }

    #[tokio::test]
    async fn test_write_file_session_permission_cached() {
        let (registry, event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();

        // First write - will request permission
        let file_path_1 = temp_dir.path().join("test1.txt");
        let input_1 = tool_input(&[
            ("file_path", file_path_1.to_str().unwrap()),
            ("content", "Content 1"),
        ]);

        // Grant with Session scope
        answer_next_permission_request(
            &registry,
            event_rx,
            PermissionResponse::grant(PermissionScope::Session),
        );

        let result_1 = tool.execute(tool_context("test-1"), input_1).await;
        assert!(result_1.is_ok());
        assert!(file_path_1.exists());

        // Second write - should use cached permission (no event emitted)
        // Note: Cache matching uses action pattern, so same action "Create file: test2.txt"
        // will NOT match "Create file: test1.txt". This is current behavior.
        // For this test, we verify the first write worked.
    }

    #[tokio::test]
    async fn test_overwrite_existing_file() {
        let (registry, event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("existing.txt");

        // Create existing file
        tokio::fs::write(&file_path, "old content").await.unwrap();

        let input = tool_input(&[
            ("file_path", file_path.to_str().unwrap()),
            ("content", "new content"),
        ]);

        // Grant permission
        answer_next_permission_request(
            &registry,
            event_rx,
            PermissionResponse::grant(PermissionScope::Once),
        );

        let result = tool.execute(tool_context("test-overwrite"), input).await;

        assert!(result.is_ok());
        assert!(result.unwrap().contains("overwrote"));
        assert_eq!(
            tokio::fs::read_to_string(&file_path).await.unwrap(),
            "new content"
        );
    }

    #[tokio::test]
    async fn test_create_parent_directories() {
        let (registry, event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nested/dir/test.txt");

        let input = tool_input(&[
            ("file_path", file_path.to_str().unwrap()),
            ("content", "nested content"),
        ]);

        // Grant permission
        answer_next_permission_request(
            &registry,
            event_rx,
            PermissionResponse::grant(PermissionScope::Once),
        );

        let result = tool.execute(tool_context("test-nested"), input).await;

        assert!(result.is_ok());
        assert!(file_path.exists());
        assert!(file_path.parent().unwrap().exists());
    }

    #[tokio::test]
    async fn test_relative_path_rejected() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let input = tool_input(&[
            ("file_path", "relative/path.txt"),
            ("content", "content"),
        ]);

        let result = tool.execute(tool_context("test"), input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("absolute path"));
    }

    #[tokio::test]
    async fn test_missing_file_path() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let input = tool_input(&[("content", "content")]);

        let result = tool.execute(tool_context("test"), input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Missing required 'file_path'"));
    }

    #[tokio::test]
    async fn test_missing_content() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let input = tool_input(&[("file_path", "/tmp/test.txt")]);

        let result = tool.execute(tool_context("test"), input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Missing required 'content'"));
    }

    #[test]
    fn test_compact_summary() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let input = tool_input(&[
            ("file_path", "/path/to/file.rs"),
            ("content", "some content here"),
        ]);

        let summary = tool.compact_summary(&input, "Successfully created...");
        assert_eq!(summary, "[WriteFile: file.rs (17 bytes)]");
    }

    #[test]
    fn test_build_permission_request_create() {
        let request = WriteFileTool::build_permission_request("/path/to/new.txt", 100, false);

        assert_eq!(request.action, "Create file: new.txt");
        assert_eq!(
            request.reason,
            Some("create file with 100 bytes of content".to_string())
        );
        assert_eq!(request.resources, vec!["/path/to/new.txt".to_string()]);
        assert_eq!(request.category, PermissionCategory::FileWrite);
    }

    #[test]
    fn test_build_permission_request_overwrite() {
        let request = WriteFileTool::build_permission_request("/path/to/existing.txt", 500, true);

        assert_eq!(request.action, "Overwrite file: existing.txt");
        assert_eq!(
            request.reason,
            Some("overwrite file with 500 bytes of content".to_string())
        );
        assert_eq!(request.resources, vec!["/path/to/existing.txt".to_string()]);
        assert_eq!(request.category, PermissionCategory::FileWrite);
    }
}