1use std::collections::HashMap;
8use std::future::Future;
9use std::path::Path;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use tokio::fs;
14
15use super::ask_for_permissions::{PermissionCategory, PermissionRequest};
16use super::permission_registry::PermissionRegistry;
17use super::types::{
18 DisplayConfig, DisplayResult, Executable, ResultContentType, ToolContext, ToolType,
19};
20
/// Tool name returned by `Executable::name` for [`WriteFileTool`].
pub const WRITE_FILE_TOOL_NAME: &str = "write_file";

/// Tool description returned by `Executable::description`; explains the
/// path, overwrite, directory-creation and permission behavior of the tool.
pub const WRITE_FILE_TOOL_DESCRIPTION: &str = r#"Writes content to a file, creating it if it doesn't exist or overwriting if it does.

Usage:
- The file_path parameter must be an absolute path, not a relative path
- This tool will overwrite the existing file if there is one at the provided path
- Parent directories will be created automatically if they don't exist
- Requires user permission before writing (may be cached for session)

Returns:
- Success message with bytes written on successful write
- Error message if permission is denied or the operation fails"#;

/// JSON Schema for the tool's input object, returned by `Executable::input_schema`.
/// `file_path` and `content` are required; `create_directories` is optional and
/// `execute` treats a missing value as `true`.
pub const WRITE_FILE_TOOL_SCHEMA: &str = r#"{
    "type": "object",
    "properties": {
        "file_path": {
            "type": "string",
            "description": "The absolute path to the file to write"
        },
        "content": {
            "type": "string",
            "description": "The content to write to the file"
        },
        "create_directories": {
            "type": "boolean",
            "description": "Whether to create parent directories if they don't exist. Defaults to true."
        }
    },
    "required": ["file_path", "content"]
}"#;
56
/// Tool that creates or overwrites a file on disk, asking the user for
/// permission (via a shared [`PermissionRegistry`]) before each write.
pub struct WriteFileTool {
    /// Registry consulted for an existing grant (`is_granted`) and used to
    /// register a new permission request (`register`) before writing.
    permission_registry: Arc<PermissionRegistry>,
}
62
63impl WriteFileTool {
64 pub fn new(permission_registry: Arc<PermissionRegistry>) -> Self {
69 Self { permission_registry }
70 }
71
72 fn build_permission_request(
74 file_path: &str,
75 content_len: usize,
76 is_overwrite: bool,
77 ) -> PermissionRequest {
78 let action_verb = if is_overwrite { "Overwrite" } else { "Create" };
79 let filename = Path::new(file_path)
80 .file_name()
81 .and_then(|n| n.to_str())
82 .unwrap_or(file_path);
83
84 PermissionRequest {
85 action: format!("{} file: {}", action_verb, filename),
86 reason: Some(format!(
87 "{} file with {} bytes of content",
88 action_verb.to_lowercase(),
89 content_len
90 )),
91 resources: vec![file_path.to_string()],
92 category: PermissionCategory::FileWrite,
93 }
94 }
95}
96
97impl Executable for WriteFileTool {
98 fn name(&self) -> &str {
99 WRITE_FILE_TOOL_NAME
100 }
101
102 fn description(&self) -> &str {
103 WRITE_FILE_TOOL_DESCRIPTION
104 }
105
106 fn input_schema(&self) -> &str {
107 WRITE_FILE_TOOL_SCHEMA
108 }
109
110 fn tool_type(&self) -> ToolType {
111 ToolType::TextEdit
112 }
113
114 fn execute(
115 &self,
116 context: ToolContext,
117 input: HashMap<String, serde_json::Value>,
118 ) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> {
119 let permission_registry = self.permission_registry.clone();
120
121 Box::pin(async move {
122 let file_path = input
126 .get("file_path")
127 .and_then(|v| v.as_str())
128 .ok_or_else(|| "Missing required 'file_path' parameter".to_string())?;
129
130 let content = input
131 .get("content")
132 .and_then(|v| v.as_str())
133 .ok_or_else(|| "Missing required 'content' parameter".to_string())?;
134
135 let create_directories = input
136 .get("create_directories")
137 .and_then(|v| v.as_bool())
138 .unwrap_or(true);
139
140 let path = Path::new(file_path);
141
142 if !path.is_absolute() {
144 return Err(format!(
145 "file_path must be an absolute path, got: {}",
146 file_path
147 ));
148 }
149
150 let is_overwrite = path.exists();
152
153 let permission_request =
157 Self::build_permission_request(file_path, content.len(), is_overwrite);
158
159 let already_granted = permission_registry
163 .is_granted(context.session_id, &permission_request)
164 .await;
165
166 if !already_granted {
167 let response_rx = permission_registry
172 .register(
173 context.tool_use_id.clone(),
174 context.session_id,
175 permission_request,
176 context.turn_id.clone(),
177 )
178 .await
179 .map_err(|e| format!("Failed to request permission: {}", e))?;
180
181 let response = response_rx
185 .await
186 .map_err(|_| "Permission request was cancelled".to_string())?;
187
188 if !response.granted {
192 let reason = response
193 .message
194 .unwrap_or_else(|| "Permission denied by user".to_string());
195 return Err(format!(
196 "Permission denied to write '{}': {}",
197 file_path, reason
198 ));
199 }
200 }
201
202 if create_directories {
206 if let Some(parent) = path.parent() {
207 if !parent.exists() {
208 fs::create_dir_all(parent).await.map_err(|e| {
209 format!("Failed to create parent directories: {}", e)
210 })?;
211 }
212 }
213 }
214
215 let bytes_written = content.len();
219 fs::write(path, content).await.map_err(|e| {
220 format!("Failed to write file '{}': {}", file_path, e)
221 })?;
222
223 let action = if is_overwrite { "overwrote" } else { "created" };
224 Ok(format!(
225 "Successfully {} '{}' ({} bytes)",
226 action, file_path, bytes_written
227 ))
228 })
229 }
230
231 fn display_config(&self) -> DisplayConfig {
232 DisplayConfig {
233 display_name: "Write File".to_string(),
234 display_title: Box::new(|input| {
235 input
236 .get("file_path")
237 .and_then(|v| v.as_str())
238 .map(|p| {
239 Path::new(p)
240 .file_name()
241 .and_then(|n| n.to_str())
242 .unwrap_or(p)
243 .to_string()
244 })
245 .unwrap_or_default()
246 }),
247 display_content: Box::new(|input, result| {
248 let content_preview = input
249 .get("content")
250 .and_then(|v| v.as_str())
251 .map(|c| {
252 let lines: Vec<&str> = c.lines().take(10).collect();
253 if c.lines().count() > 10 {
254 format!("{}...\n[truncated]", lines.join("\n"))
255 } else {
256 lines.join("\n")
257 }
258 })
259 .unwrap_or_else(|| result.to_string());
260
261 DisplayResult {
262 content: content_preview,
263 content_type: ResultContentType::PlainText,
264 is_truncated: input
265 .get("content")
266 .and_then(|v| v.as_str())
267 .map(|c| c.lines().count() > 10)
268 .unwrap_or(false),
269 full_length: input
270 .get("content")
271 .and_then(|v| v.as_str())
272 .map(|c| c.lines().count())
273 .unwrap_or(0),
274 }
275 }),
276 }
277 }
278
279 fn compact_summary(
280 &self,
281 input: &HashMap<String, serde_json::Value>,
282 _result: &str,
283 ) -> String {
284 let filename = input
285 .get("file_path")
286 .and_then(|v| v.as_str())
287 .map(|p| {
288 Path::new(p)
289 .file_name()
290 .and_then(|n| n.to_str())
291 .unwrap_or(p)
292 })
293 .unwrap_or("unknown");
294
295 let bytes = input
296 .get("content")
297 .and_then(|v| v.as_str())
298 .map(|c| c.len())
299 .unwrap_or(0);
300
301 format!("[WriteFile: {} ({} bytes)]", filename, bytes)
302 }
303}
304
// Unit tests for `WriteFileTool`. Tests that need a permission decision spawn a
// task that answers the first `PermissionRequired` event read from the registry's
// event channel.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::tools::ask_for_permissions::{PermissionResponse, PermissionScope};
    use crate::controller::types::ControllerEvent;
    use tempfile::TempDir;
    use tokio::sync::mpsc;

    /// Builds a registry wired to a fresh event channel (capacity 16) and returns
    /// the receiver so a test can observe and answer permission requests.
    fn create_test_registry() -> (Arc<PermissionRegistry>, mpsc::Receiver<ControllerEvent>) {
        let (tx, rx) = mpsc::channel(16);
        let registry = Arc::new(PermissionRegistry::new(tx));
        (registry, rx)
    }

    // Granting a one-time permission lets a new file be created with the given content.
    #[tokio::test]
    async fn test_write_new_file_with_permission_granted() {
        let (registry, mut event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        let mut input = HashMap::new();
        input.insert(
            "file_path".to_string(),
            serde_json::Value::String(file_path.to_str().unwrap().to_string()),
        );
        input.insert(
            "content".to_string(),
            serde_json::Value::String("Hello, World!".to_string()),
        );

        let context = ToolContext {
            session_id: 1,
            tool_use_id: "test-123".to_string(),
            turn_id: None,
        };

        let registry_clone = registry.clone();
        // Simulated user: grant the first permission request once.
        tokio::spawn(async move {
            if let Some(ControllerEvent::PermissionRequired { tool_use_id, .. }) =
                event_rx.recv().await
            {
                registry_clone
                    .respond(&tool_use_id, PermissionResponse::grant(PermissionScope::Once))
                    .await
                    .unwrap();
            }
        });

        let result = tool.execute(context, input).await;

        assert!(result.is_ok());
        assert!(file_path.exists());
        assert_eq!(
            tokio::fs::read_to_string(&file_path).await.unwrap(),
            "Hello, World!"
        );
    }

    // A denied request returns an error mentioning "Permission denied" and leaves
    // no file behind.
    #[tokio::test]
    async fn test_write_file_permission_denied() {
        let (registry, mut event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test.txt");

        let mut input = HashMap::new();
        input.insert(
            "file_path".to_string(),
            serde_json::Value::String(file_path.to_str().unwrap().to_string()),
        );
        input.insert(
            "content".to_string(),
            serde_json::Value::String("Hello, World!".to_string()),
        );

        let context = ToolContext {
            session_id: 1,
            tool_use_id: "test-456".to_string(),
            turn_id: None,
        };

        let registry_clone = registry.clone();
        // Simulated user: deny the first permission request.
        tokio::spawn(async move {
            if let Some(ControllerEvent::PermissionRequired { tool_use_id, .. }) =
                event_rx.recv().await
            {
                registry_clone
                    .respond(
                        &tool_use_id,
                        PermissionResponse::deny(Some("Not allowed".to_string())),
                    )
                    .await
                    .unwrap();
            }
        });

        let result = tool.execute(context, input).await;

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Permission denied"));
        assert!(!file_path.exists());
    }

    // Grants the first request with `PermissionScope::Session` and checks that the
    // write succeeds.
    // NOTE(review): despite its name, this test never performs a second write, so
    // it does not verify that the session-scoped grant is reused. Exercising the
    // cache would need a follow-up write (guarded by a timeout in case a second
    // prompt is issued); confirm the registry's grant-matching rules first.
    #[tokio::test]
    async fn test_write_file_session_permission_cached() {
        let (registry, mut event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();

        let file_path_1 = temp_dir.path().join("test1.txt");
        let mut input_1 = HashMap::new();
        input_1.insert(
            "file_path".to_string(),
            serde_json::Value::String(file_path_1.to_str().unwrap().to_string()),
        );
        input_1.insert(
            "content".to_string(),
            serde_json::Value::String("Content 1".to_string()),
        );

        let context_1 = ToolContext {
            session_id: 1,
            tool_use_id: "test-1".to_string(),
            turn_id: None,
        };

        let registry_clone = registry.clone();
        // Simulated user: grant the first request for the whole session.
        tokio::spawn(async move {
            if let Some(ControllerEvent::PermissionRequired { tool_use_id, .. }) =
                event_rx.recv().await
            {
                registry_clone
                    .respond(
                        &tool_use_id,
                        PermissionResponse::grant(PermissionScope::Session),
                    )
                    .await
                    .unwrap();
            }
        });

        let result_1 = tool.execute(context_1, input_1).await;
        assert!(result_1.is_ok());
        assert!(file_path_1.exists());
    }

    // Writing to an existing file replaces its content, and the success message
    // says "overwrote".
    #[tokio::test]
    async fn test_overwrite_existing_file() {
        let (registry, mut event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("existing.txt");

        tokio::fs::write(&file_path, "old content").await.unwrap();

        let mut input = HashMap::new();
        input.insert(
            "file_path".to_string(),
            serde_json::Value::String(file_path.to_str().unwrap().to_string()),
        );
        input.insert(
            "content".to_string(),
            serde_json::Value::String("new content".to_string()),
        );

        let context = ToolContext {
            session_id: 1,
            tool_use_id: "test-overwrite".to_string(),
            turn_id: None,
        };

        let registry_clone = registry.clone();
        // Simulated user: grant the first permission request once.
        tokio::spawn(async move {
            if let Some(ControllerEvent::PermissionRequired { tool_use_id, .. }) =
                event_rx.recv().await
            {
                registry_clone
                    .respond(&tool_use_id, PermissionResponse::grant(PermissionScope::Once))
                    .await
                    .unwrap();
            }
        });

        let result = tool.execute(context, input).await;

        assert!(result.is_ok());
        assert!(result.unwrap().contains("overwrote"));
        assert_eq!(
            tokio::fs::read_to_string(&file_path).await.unwrap(),
            "new content"
        );
    }

    // Missing parent directories are created when `create_directories` is left
    // at its default.
    #[tokio::test]
    async fn test_create_parent_directories() {
        let (registry, mut event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry.clone());
        let temp_dir = TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nested/dir/test.txt");

        let mut input = HashMap::new();
        input.insert(
            "file_path".to_string(),
            serde_json::Value::String(file_path.to_str().unwrap().to_string()),
        );
        input.insert(
            "content".to_string(),
            serde_json::Value::String("nested content".to_string()),
        );

        let context = ToolContext {
            session_id: 1,
            tool_use_id: "test-nested".to_string(),
            turn_id: None,
        };

        let registry_clone = registry.clone();
        // Simulated user: grant the first permission request once.
        tokio::spawn(async move {
            if let Some(ControllerEvent::PermissionRequired { tool_use_id, .. }) =
                event_rx.recv().await
            {
                registry_clone
                    .respond(&tool_use_id, PermissionResponse::grant(PermissionScope::Once))
                    .await
                    .unwrap();
            }
        });

        let result = tool.execute(context, input).await;

        assert!(result.is_ok());
        assert!(file_path.exists());
        assert!(file_path.parent().unwrap().exists());
    }

    // Relative paths are rejected before any permission request is made.
    #[tokio::test]
    async fn test_relative_path_rejected() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let mut input = HashMap::new();
        input.insert(
            "file_path".to_string(),
            serde_json::Value::String("relative/path.txt".to_string()),
        );
        input.insert(
            "content".to_string(),
            serde_json::Value::String("content".to_string()),
        );

        let context = ToolContext {
            session_id: 1,
            tool_use_id: "test".to_string(),
            turn_id: None,
        };

        let result = tool.execute(context, input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("absolute path"));
    }

    // A missing `file_path` key yields the corresponding parameter error.
    #[tokio::test]
    async fn test_missing_file_path() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let mut input = HashMap::new();
        input.insert(
            "content".to_string(),
            serde_json::Value::String("content".to_string()),
        );

        let context = ToolContext {
            session_id: 1,
            tool_use_id: "test".to_string(),
            turn_id: None,
        };

        let result = tool.execute(context, input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Missing required 'file_path'"));
    }

    // A missing `content` key yields the corresponding parameter error.
    #[tokio::test]
    async fn test_missing_content() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let mut input = HashMap::new();
        input.insert(
            "file_path".to_string(),
            serde_json::Value::String("/tmp/test.txt".to_string()),
        );

        let context = ToolContext {
            session_id: 1,
            tool_use_id: "test".to_string(),
            turn_id: None,
        };

        let result = tool.execute(context, input).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Missing required 'content'"));
    }

    // The compact summary shows only the file name and the content length in bytes.
    #[test]
    fn test_compact_summary() {
        let (registry, _event_rx) = create_test_registry();
        let tool = WriteFileTool::new(registry);

        let mut input = HashMap::new();
        input.insert(
            "file_path".to_string(),
            serde_json::Value::String("/path/to/file.rs".to_string()),
        );
        input.insert(
            "content".to_string(),
            serde_json::Value::String("some content here".to_string()),
        );

        let summary = tool.compact_summary(&input, "Successfully created...");
        assert_eq!(summary, "[WriteFile: file.rs (17 bytes)]");
    }

    // A request for a new file uses the "Create" wording.
    #[test]
    fn test_build_permission_request_create() {
        let request = WriteFileTool::build_permission_request("/path/to/new.txt", 100, false);

        assert_eq!(request.action, "Create file: new.txt");
        assert_eq!(
            request.reason,
            Some("create file with 100 bytes of content".to_string())
        );
        assert_eq!(request.resources, vec!["/path/to/new.txt".to_string()]);
        assert_eq!(request.category, PermissionCategory::FileWrite);
    }

    // A request for an existing file uses the "Overwrite" wording.
    #[test]
    fn test_build_permission_request_overwrite() {
        let request = WriteFileTool::build_permission_request("/path/to/existing.txt", 500, true);

        assert_eq!(request.action, "Overwrite file: existing.txt");
        assert_eq!(
            request.reason,
            Some("overwrite file with 500 bytes of content".to_string())
        );
        assert_eq!(request.resources, vec!["/path/to/existing.txt".to_string()]);
        assert_eq!(request.category, PermissionCategory::FileWrite);
    }
}