llm_coding_tools_rig/allowed/
write.rs

1//! Write file tool using [`AllowedPathResolver`].
2
3use llm_coding_tools_core::operations::write_file;
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6use llm_coding_tools_core::{ToolContext, ToolError};
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use schemars::{schema_for, JsonSchema};
10use serde::Deserialize;
11
12/// Arguments for the write tool.
13#[derive(Debug, Clone, Deserialize, JsonSchema)]
14pub struct WriteToolArgs {
15    /// Relative path for the file to write (within allowed directories).
16    pub file_path: String,
17    /// Content to write to the file.
18    pub content: String,
19}
20
21/// Tool for writing content to files within allowed directories.
22#[derive(Debug, Clone)]
23pub struct WriteTool {
24    resolver: AllowedPathResolver,
25}
26
27impl WriteTool {
28    /// Creates a new write tool with a shared resolver.
29    ///
30    /// See [`ReadTool::new`] for usage example.
31    ///
32    /// [`ReadTool::new`]: super::ReadTool::new
33    pub fn new(resolver: AllowedPathResolver) -> Self {
34        Self { resolver }
35    }
36}
37
38impl Tool for WriteTool {
39    const NAME: &'static str = tool_names::WRITE;
40
41    type Error = ToolError;
42    type Args = WriteToolArgs;
43    type Output = String;
44
45    async fn definition(&self, _prompt: String) -> ToolDefinition {
46        ToolDefinition {
47            name: <Self as Tool>::NAME.to_string(),
48            description: "Write content to a file within allowed directories. \
49                          Paths are relative to configured base directories."
50                .to_string(),
51            parameters: serde_json::to_value(schema_for!(WriteToolArgs))
52                .expect("schema generation should not fail"),
53        }
54    }
55
56    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
57        write_file(&self.resolver, &args.file_path, &args.content).await
58    }
59}
60
61impl ToolContext for WriteTool {
62    const NAME: &'static str = tool_names::WRITE;
63
64    fn context(&self) -> &'static str {
65        llm_coding_tools_core::context::WRITE_ALLOWED
66    }
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a tool whose resolver only permits paths inside `dir`.
    fn tool_in(dir: &TempDir) -> WriteTool {
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        WriteTool::new(resolver)
    }

    #[tokio::test]
    async fn writes_new_file() {
        let dir = TempDir::new().unwrap();
        let tool = tool_in(&dir);
        let result = tool
            .call(WriteToolArgs {
                file_path: "new.txt".to_string(),
                content: "hello".to_string(),
            })
            .await
            .unwrap();
        assert!(result.contains("5 bytes"));
        // Checking existence alone would also pass for an empty or truncated
        // write, so verify the bytes that actually landed on disk.
        let written = std::fs::read_to_string(dir.path().join("new.txt")).unwrap();
        assert_eq!(written, "hello");
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();
        let tool = tool_in(&dir);
        let result = tool
            .call(WriteToolArgs {
                file_path: "../../../tmp/escape.txt".to_string(),
                content: "content".to_string(),
            })
            .await;
        assert!(matches!(result, Err(ToolError::InvalidPath(_))));
    }

    #[tokio::test]
    async fn definition_exposes_name_and_argument_schema() {
        let dir = TempDir::new().unwrap();
        let def = tool_in(&dir).definition(String::new()).await;
        assert_eq!(def.name, tool_names::WRITE);
        // Both argument fields must be advertised to the model.
        let props = &def.parameters["properties"];
        assert!(props["file_path"].is_object());
        assert!(props["content"].is_object());
    }
}