llm_coding_tools_rig/allowed/
edit.rs1use llm_coding_tools_core::operations::edit_file;
4use llm_coding_tools_core::path::AllowedPathResolver;
5use llm_coding_tools_core::tool_names;
6pub use llm_coding_tools_core::EditError;
7use llm_coding_tools_core::ToolContext;
8use rig::completion::ToolDefinition;
9use rig::tool::Tool;
10use schemars::{schema_for, JsonSchema};
11use serde::Deserialize;
12
13#[derive(Debug, Clone, Deserialize, JsonSchema)]
15pub struct EditArgs {
16 pub file_path: String,
18 pub old_string: String,
20 pub new_string: String,
22 #[serde(default)]
24 pub replace_all: bool,
25}
26
27#[derive(Debug, Clone)]
29pub struct EditTool {
30 resolver: AllowedPathResolver,
31}
32
33impl EditTool {
34 pub fn new(resolver: AllowedPathResolver) -> Self {
38 Self { resolver }
39 }
40}
41
42impl Tool for EditTool {
43 const NAME: &'static str = tool_names::EDIT;
44
45 type Error = EditError;
46 type Args = EditArgs;
47 type Output = String;
48
49 async fn definition(&self, _prompt: String) -> ToolDefinition {
50 ToolDefinition {
51 name: <Self as Tool>::NAME.to_string(),
52 description: "Make exact string replacements in files within allowed directories. \
53 Paths are relative to configured base directories."
54 .to_string(),
55 parameters: serde_json::to_value(schema_for!(EditArgs))
56 .expect("EditArgs schema generation should not fail"),
57 }
58 }
59
60 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
61 edit_file(
62 &self.resolver,
63 &args.file_path,
64 &args.old_string,
65 &args.new_string,
66 args.replace_all,
67 )
68 .await
69 }
70}
71
72impl ToolContext for EditTool {
73 const NAME: &'static str = tool_names::EDIT;
74
75 fn context(&self) -> &'static str {
76 llm_coding_tools_core::context::EDIT_ALLOWED
77 }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use llm_coding_tools_core::ToolError;
    use tempfile::TempDir;

    // Builds an `EditTool` whose only allowed directory is `dir`.
    fn tool_in(dir: &TempDir) -> EditTool {
        EditTool::new(AllowedPathResolver::new([dir.path()]).unwrap())
    }

    // Shorthand for constructing `EditArgs` in tests.
    fn edit_args(file_path: &str, old: &str, new: &str, replace_all: bool) -> EditArgs {
        EditArgs {
            file_path: file_path.to_string(),
            old_string: old.to_string(),
            new_string: new.to_string(),
            replace_all,
        }
    }

    #[tokio::test]
    async fn replaces_single_occurrence() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.txt");
        std::fs::write(&path, "hello world").unwrap();

        let result = tool_in(&dir)
            .call(edit_args("test.txt", "world", "rust", false))
            .await
            .unwrap();

        assert!(result.contains("1 occurrence"));
        // The summary alone does not prove the edit reached the file.
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello rust");
    }

    #[tokio::test]
    async fn replace_all_replaces_every_occurrence() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.txt");
        std::fs::write(&path, "foo bar foo").unwrap();

        tool_in(&dir)
            .call(edit_args("test.txt", "foo", "baz", true))
            .await
            .unwrap();

        assert_eq!(std::fs::read_to_string(&path).unwrap(), "baz bar baz");
    }

    #[tokio::test]
    async fn rejects_path_traversal() {
        let dir = TempDir::new().unwrap();

        let result = tool_in(&dir)
            .call(edit_args("../../../etc/passwd", "old", "new", false))
            .await;

        assert!(matches!(
            result,
            Err(EditError::Tool(ToolError::InvalidPath(_)))
        ));
    }
}