llm_coding_tools_serdesai/absolute/
edit.rs1use async_trait::async_trait;
4use llm_coding_tools_core::ToolContext;
5use llm_coding_tools_core::operations::edit_file;
6use llm_coding_tools_core::path::AbsolutePathResolver;
7use llm_coding_tools_core::tool_names;
8use serde::Deserialize;
9use serdes_ai::tools::{
10 RunContext, SchemaBuilder, Tool, ToolDefinition, ToolError, ToolResult, ToolReturn,
11};
12
13use crate::common::edit::error_to_serdes;
14
15#[derive(Debug, Deserialize)]
17struct EditArgs {
18 file_path: String,
20 old_string: String,
22 new_string: String,
24 #[serde(default)]
26 replace_all: bool,
27}
28
/// Tool that makes exact string replacements in files addressed by absolute paths.
///
/// Paths are resolved with `AbsolutePathResolver`, and the edit itself is delegated to
/// `llm_coding_tools_core::operations::edit_file`.
#[derive(Debug, Clone, Default)]
pub struct EditTool;

impl EditTool {
    /// Creates a new edit tool; equivalent to [`EditTool::default`].
    ///
    /// Declared `const` so the (stateless) tool can also be built in `const`/`static`
    /// contexts.
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
40
41#[async_trait]
42impl<Deps: Send + Sync> Tool<Deps> for EditTool {
43 fn definition(&self) -> ToolDefinition {
44 let schema = SchemaBuilder::new()
45 .string("file_path", "Absolute path to the file", true)
46 .string("old_string", "The exact text to find and replace", true)
47 .string("new_string", "The text to replace with", true)
48 .boolean(
49 "replace_all",
50 "Replace all occurrences instead of just the first. Defaults to false.",
51 false,
52 )
53 .build()
54 .expect("schema build should not fail");
55
56 ToolDefinition::new(
57 tool_names::EDIT,
58 "Makes exact string replacements in files. Use replace_all=true to replace all occurrences.",
59 )
60 .with_parameters(schema)
61 }
62
63 async fn call(&self, _ctx: &RunContext<Deps>, args: serde_json::Value) -> ToolResult {
64 let args: EditArgs = serde_json::from_value(args)
65 .map_err(|e| ToolError::validation_error(tool_names::EDIT, None, e.to_string()))?;
66
67 let resolver = AbsolutePathResolver;
68 let result = edit_file(
69 &resolver,
70 &args.file_path,
71 &args.old_string,
72 &args.new_string,
73 args.replace_all,
74 )
75 .await;
76
77 result.map(ToolReturn::text).map_err(error_to_serdes)
78 }
79}
80
81impl ToolContext for EditTool {
82 const NAME: &'static str = tool_names::EDIT;
83
84 fn context(&self) -> &'static str {
85 llm_coding_tools_core::context::EDIT_ABSOLUTE
86 }
87}
88
#[cfg(test)]
mod tests {
    // `RunContext`, `EditTool`, `Tool` and `ToolError` all come in through the glob.
    use super::*;
    use serde_json::json;
    use std::io::Write as _;
    use tempfile::NamedTempFile;

    fn mock_ctx() -> RunContext<()> {
        RunContext::new((), "test-model")
    }

    /// Creates a temporary file pre-filled with `contents`.
    fn temp_file_with(contents: &[u8]) -> NamedTempFile {
        let mut file = NamedTempFile::new().expect("failed to create temp file");
        file.write_all(contents).expect("failed to write temp file");
        file.flush().expect("failed to flush temp file");
        file
    }

    /// Reads back the current contents of `file`.
    fn contents(file: &NamedTempFile) -> String {
        std::fs::read_to_string(file.path()).expect("failed to read temp file")
    }

    /// Asserts that `err` is `ToolError::ValidationFailed` and that its first
    /// validation message contains `needle`.
    fn assert_validation_failed(err: ToolError, needle: &str) {
        match err {
            ToolError::ValidationFailed { errors, .. } => {
                let first = errors
                    .first()
                    .expect("ValidationFailed should carry at least one error");
                assert!(
                    first.message.contains(needle),
                    "expected validation message {:?} to contain {:?}",
                    first.message,
                    needle
                );
            }
            _ => panic!("Expected ValidationFailed"),
        }
    }

    #[tokio::test]
    async fn edit_success() {
        let file = temp_file_with(b"hello world");

        let tool = EditTool::new();
        let result = tool
            .call(
                &mock_ctx(),
                json!({
                    "file_path": file.path().to_string_lossy(),
                    "old_string": "world",
                    "new_string": "rust"
                }),
            )
            .await
            .unwrap();

        let text = result.as_text().unwrap();
        assert!(text.contains("1 occurrence"));
        assert_eq!(contents(&file), "hello rust");
    }

    #[tokio::test]
    async fn edit_replace_all_replaces_every_occurrence() {
        let file = temp_file_with(b"hello hello hello");

        let tool = EditTool::new();
        let result = tool
            .call(
                &mock_ctx(),
                json!({
                    "file_path": file.path().to_string_lossy(),
                    "old_string": "hello",
                    "new_string": "world",
                    "replace_all": true
                }),
            )
            .await
            .unwrap();

        let text = result.as_text().unwrap();
        assert!(text.contains("3 occurrence"));
        assert_eq!(contents(&file), "world world world");
    }

    #[tokio::test]
    async fn edit_not_found_error() {
        let file = temp_file_with(b"hello world");

        let tool = EditTool::new();
        let result = tool
            .call(
                &mock_ctx(),
                json!({
                    "file_path": file.path().to_string_lossy(),
                    "old_string": "not_found",
                    "new_string": "replacement"
                }),
            )
            .await;

        assert_validation_failed(result.unwrap_err(), "not found");
        // A failed edit must leave the file untouched.
        assert_eq!(contents(&file), "hello world");
    }

    #[tokio::test]
    async fn edit_ambiguous_match_error() {
        let file = temp_file_with(b"hello hello hello");

        let tool = EditTool::new();
        let result = tool
            .call(
                &mock_ctx(),
                json!({
                    "file_path": file.path().to_string_lossy(),
                    "old_string": "hello",
                    "new_string": "world",
                    "replace_all": false
                }),
            )
            .await;

        assert_validation_failed(result.unwrap_err(), "3 times");
        assert_eq!(contents(&file), "hello hello hello");
    }

    #[tokio::test]
    async fn edit_rejects_missing_required_argument() {
        let file = temp_file_with(b"hello world");

        let tool = EditTool::new();
        let result = tool
            .call(
                &mock_ctx(),
                json!({
                    "file_path": file.path().to_string_lossy(),
                    "old_string": "world"
                }),
            )
            .await;

        // serde reports the absent required field by name.
        assert_validation_failed(result.unwrap_err(), "new_string");
        assert_eq!(contents(&file), "hello world");
    }
}