llm_coding_tools_rig/allowed/
edit.rs

1//! Edit file tool using [`AllowedPathResolver`].
2
3use llm_coding_tools_core::operations::edit_file;
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6pub use llm_coding_tools_core::EditError;
7use llm_coding_tools_core::ToolContext;
8use rig::completion::ToolDefinition;
9use rig::tool::Tool;
10use schemars::{schema_for, JsonSchema};
11use serde::Deserialize;
12
// NOTE: this type derives `JsonSchema` and is fed to `schema_for!` in
// `EditTool::definition`. schemars copies the `///` doc comments below into the
// generated schema as `description` text, so they are part of what callers of the
// tool definition receive. Treat edits to them as behavior changes and keep
// maintainer-only notes in plain `//` comments like this one.
/// Arguments for file editing.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct EditArgs {
    /// Relative path to the file to modify (within allowed directories).
    pub file_path: String,
    /// Exact text to find and replace.
    pub old_string: String,
    /// Replacement text.
    pub new_string: String,
    /// Replace all occurrences (default false).
    // `serde(default)` lets the field be omitted from the input; it then takes
    // `bool::default()`, i.e. `false`, matching the description above.
    #[serde(default)]
    pub replace_all: bool,
}
26
/// Tool for making exact string replacements in files within allowed directories.
///
/// The tool itself holds no editing logic: its [`Tool::call`] implementation
/// forwards the arguments, together with the stored resolver, to
/// [`edit_file`].
#[derive(Debug, Clone)]
pub struct EditTool {
    /// Resolver handed to [`edit_file`] on every call.
    resolver: AllowedPathResolver,
}
32
33impl EditTool {
34    /// Creates a new edit tool with a shared resolver.
35    ///
36    /// See [`ReadTool::new`](crate::allowed::read::ReadTool::new) for usage example.
37    pub fn new(resolver: AllowedPathResolver) -> Self {
38        Self { resolver }
39    }
40}
41
42impl Tool for EditTool {
43    const NAME: &'static str = tool_names::EDIT;
44
45    type Error = EditError;
46    type Args = EditArgs;
47    type Output = String;
48
49    async fn definition(&self, _prompt: String) -> ToolDefinition {
50        ToolDefinition {
51            name: <Self as Tool>::NAME.to_string(),
52            description: "Make exact string replacements in files within allowed directories. \
53                          Paths are relative to configured base directories."
54                .to_string(),
55            parameters: serde_json::to_value(schema_for!(EditArgs))
56                .expect("EditArgs schema generation should not fail"),
57        }
58    }
59
60    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
61        edit_file(
62            &self.resolver,
63            &args.file_path,
64            &args.old_string,
65            &args.new_string,
66            args.replace_all,
67        )
68        .await
69    }
70}
71
72impl ToolContext for EditTool {
73    const NAME: &'static str = tool_names::EDIT;
74
75    fn context(&self) -> &'static str {
76        llm_coding_tools_core::context::EDIT_ALLOWED
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use llm_coding_tools_core::ToolError;
    use tempfile::TempDir;

    /// A single-match edit reports one replacement and rewrites the file on disk.
    #[tokio::test]
    async fn replaces_single_occurrence() {
        let dir = TempDir::new().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "hello world").unwrap();

        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool = EditTool::new(resolver);
        let result = tool
            .call(EditArgs {
                file_path: "test.txt".to_string(),
                old_string: "world".to_string(),
                new_string: "rust".to_string(),
                replace_all: false,
            })
            .await
            .unwrap();
        assert!(result.contains("1 occurrence"));

        // The status message alone does not prove the write happened; check the
        // file contents so a no-op that still reports success is caught.
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "hello rust");
    }

    /// A path that climbs out of the allowed directory is rejected as invalid.
    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();
        let resolver = AllowedPathResolver::new([dir.path()]).unwrap();
        let tool = EditTool::new(resolver);
        let result = tool
            .call(EditArgs {
                file_path: "../../../etc/passwd".to_string(),
                old_string: "old".to_string(),
                new_string: "new".to_string(),
                replace_all: false,
            })
            .await;
        assert!(matches!(
            result,
            Err(EditError::Tool(ToolError::InvalidPath(_)))
        ));
    }

    /// Omitting `replace_all` from the input must deserialize to `false`, as the
    /// field's schema description promises.
    #[test]
    fn replace_all_defaults_to_false() {
        let args: EditArgs =
            serde_json::from_str(r#"{"file_path":"a.txt","old_string":"x","new_string":"y"}"#)
                .unwrap();
        assert!(!args.replace_all);
    }
}
123}