Skip to main content

agentzero_tools/
apply_patch.rs

1use agentzero_core::{Tool, ToolContext, ToolResult};
2use anyhow::{anyhow, Context};
3use async_trait::async_trait;
4use serde::Deserialize;
5use std::path::{Component, Path, PathBuf};
6use tokio::fs;
7
/// Marker that must be the first line of every patch.
const BEGIN_PATCH: &str = "*** Begin Patch";
/// Marker that closes a patch; parsing stops at this line.
const END_PATCH: &str = "*** End Patch";
/// Header prefix of a section that edits an existing file (followed by its path).
const UPDATE_FILE: &str = "*** Update File: ";
/// Header prefix of a section that creates a file (followed by its path).
const ADD_FILE: &str = "*** Add File: ";
/// Header prefix of a section that removes a file (followed by its path).
const DELETE_FILE: &str = "*** Delete File: ";
13
14#[derive(Debug, Clone)]
15struct PatchFile {
16    path: String,
17    op: PatchOp,
18}
19
20#[derive(Debug, Clone)]
21enum PatchOp {
22    Update(Vec<Hunk>),
23    Add(String),
24    Delete,
25}
26
/// A single change inside an `*** Update File:` section.
///
/// `context_before` and `context_after` hold unchanged lines surrounding the
/// change; `removals` are the lines being replaced and `additions` the lines
/// that replace them. All lines are stored without their one-character
/// diff prefix.
#[derive(Debug, Clone)]
struct Hunk {
    /// Unchanged lines preceding the first `-`/`+` line.
    context_before: Vec<String>,
    /// Lines prefixed with `-` in the patch.
    removals: Vec<String>,
    /// Lines prefixed with `+` in the patch.
    additions: Vec<String>,
    /// Unchanged lines following the change.
    // Lint silenced so the struct stays warning-free even if no code path
    // reads the trailing context.
    #[allow(dead_code)]
    context_after: Vec<String>,
}
35
36#[derive(Debug, Deserialize)]
37struct ApplyPatchInput {
38    patch: String,
39    #[serde(default)]
40    dry_run: bool,
41}
42
/// Stateless tool that applies `*** Begin Patch` style patches to files
/// inside the workspace root of the calling [`ToolContext`].
#[derive(Clone, Copy, Debug, Default)]
pub struct ApplyPatchTool;
45
46impl ApplyPatchTool {
47    pub fn validate_patch(&self, patch: &str) -> anyhow::Result<()> {
48        let trimmed = patch.trim();
49        if trimmed.is_empty() {
50            anyhow::bail!("patch must not be empty");
51        }
52        let first = trimmed
53            .lines()
54            .next()
55            .context("patch must include a begin marker")?;
56        if first != BEGIN_PATCH {
57            anyhow::bail!("patch must start with `{BEGIN_PATCH}`");
58        }
59        if !trimmed.lines().any(|line| line == END_PATCH) {
60            anyhow::bail!("patch must end with `{END_PATCH}`");
61        }
62        Ok(())
63    }
64
65    fn parse_patch(patch: &str) -> anyhow::Result<Vec<PatchFile>> {
66        let trimmed = patch.trim();
67        let lines: Vec<&str> = trimmed.lines().collect();
68        if lines.is_empty() || lines[0] != BEGIN_PATCH {
69            anyhow::bail!("patch must start with `{BEGIN_PATCH}`");
70        }
71
72        let mut files = Vec::new();
73        let mut i = 1;
74
75        while i < lines.len() {
76            let line = lines[i];
77
78            if line == END_PATCH {
79                break;
80            }
81
82            if let Some(path) = line.strip_prefix(UPDATE_FILE) {
83                let path = path.trim().to_string();
84                i += 1;
85                let mut hunks = Vec::new();
86
87                while i < lines.len()
88                    && lines[i] != END_PATCH
89                    && !lines[i].starts_with(UPDATE_FILE)
90                    && !lines[i].starts_with(ADD_FILE)
91                    && !lines[i].starts_with(DELETE_FILE)
92                {
93                    if lines[i] == "@@" {
94                        i += 1;
95                        let mut context_before = Vec::new();
96                        let mut removals = Vec::new();
97                        let mut additions = Vec::new();
98                        let mut context_after = Vec::new();
99                        let mut seen_change = false;
100
101                        while i < lines.len()
102                            && lines[i] != "@@"
103                            && lines[i] != END_PATCH
104                            && !lines[i].starts_with(UPDATE_FILE)
105                            && !lines[i].starts_with(ADD_FILE)
106                            && !lines[i].starts_with(DELETE_FILE)
107                        {
108                            if let Some(removed) = lines[i].strip_prefix('-') {
109                                seen_change = true;
110                                removals.push(removed.to_string());
111                            } else if let Some(added) = lines[i].strip_prefix('+') {
112                                seen_change = true;
113                                additions.push(added.to_string());
114                            } else if lines[i].starts_with(' ') || lines[i].is_empty() {
115                                let ctx_line = if lines[i].starts_with(' ') {
116                                    lines[i][1..].to_string()
117                                } else {
118                                    String::new()
119                                };
120                                if seen_change {
121                                    context_after.push(ctx_line);
122                                } else {
123                                    context_before.push(ctx_line);
124                                }
125                            }
126                            i += 1;
127                        }
128
129                        hunks.push(Hunk {
130                            context_before,
131                            removals,
132                            additions,
133                            context_after,
134                        });
135                    } else {
136                        i += 1;
137                    }
138                }
139
140                files.push(PatchFile {
141                    path,
142                    op: PatchOp::Update(hunks),
143                });
144            } else if let Some(path) = line.strip_prefix(ADD_FILE) {
145                let path = path.trim().to_string();
146                i += 1;
147                let mut content = Vec::new();
148
149                while i < lines.len()
150                    && lines[i] != END_PATCH
151                    && !lines[i].starts_with(UPDATE_FILE)
152                    && !lines[i].starts_with(ADD_FILE)
153                    && !lines[i].starts_with(DELETE_FILE)
154                {
155                    if let Some(added) = lines[i].strip_prefix('+') {
156                        content.push(added.to_string());
157                    }
158                    i += 1;
159                }
160
161                files.push(PatchFile {
162                    path,
163                    op: PatchOp::Add(content.join("\n") + "\n"),
164                });
165            } else if let Some(path) = line.strip_prefix(DELETE_FILE) {
166                let path = path.trim().to_string();
167                i += 1;
168                // Skip any remaining lines in this section
169                while i < lines.len()
170                    && lines[i] != END_PATCH
171                    && !lines[i].starts_with(UPDATE_FILE)
172                    && !lines[i].starts_with(ADD_FILE)
173                    && !lines[i].starts_with(DELETE_FILE)
174                {
175                    i += 1;
176                }
177                files.push(PatchFile {
178                    path,
179                    op: PatchOp::Delete,
180                });
181            } else {
182                i += 1;
183            }
184        }
185
186        if files.is_empty() {
187            anyhow::bail!("patch contains no file operations");
188        }
189
190        Ok(files)
191    }
192
193    fn apply_hunks(content: &str, hunks: &[Hunk]) -> anyhow::Result<String> {
194        let mut lines: Vec<String> = content.lines().map(|l| l.to_string()).collect();
195
196        // Apply hunks in reverse order to preserve line numbers.
197        for (hunk_idx, hunk) in hunks.iter().enumerate().rev() {
198            let match_pos = Self::find_hunk_position(&lines, hunk).with_context(|| {
199                format!(
200                    "hunk {} could not be matched against the file content",
201                    hunk_idx + 1
202                )
203            })?;
204
205            let remove_start = match_pos + hunk.context_before.len();
206            let remove_end = remove_start + hunk.removals.len();
207
208            // Verify removal lines match.
209            for (j, removal) in hunk.removals.iter().enumerate() {
210                let line_idx = remove_start + j;
211                if line_idx >= lines.len() || lines[line_idx] != *removal {
212                    let actual = if line_idx < lines.len() {
213                        &lines[line_idx]
214                    } else {
215                        "<past end of file>"
216                    };
217                    anyhow::bail!(
218                        "hunk {} removal mismatch at line {}: expected {:?}, found {:?}",
219                        hunk_idx + 1,
220                        line_idx + 1,
221                        removal,
222                        actual,
223                    );
224                }
225            }
226
227            // Replace: remove old lines and insert new ones.
228            lines.splice(remove_start..remove_end, hunk.additions.iter().cloned());
229        }
230
231        let mut result = lines.join("\n");
232        // Preserve trailing newline if original had one.
233        if content.ends_with('\n') && !result.ends_with('\n') {
234            result.push('\n');
235        }
236        Ok(result)
237    }
238
239    fn find_hunk_position(lines: &[String], hunk: &Hunk) -> anyhow::Result<usize> {
240        if hunk.context_before.is_empty() && hunk.removals.is_empty() {
241            // No context or removals — insert at the end.
242            return Ok(lines.len());
243        }
244
245        let match_lines: Vec<&str> = hunk
246            .context_before
247            .iter()
248            .chain(hunk.removals.iter())
249            .map(|s| s.as_str())
250            .collect();
251
252        if match_lines.is_empty() {
253            return Ok(0);
254        }
255
256        for start in 0..=lines.len().saturating_sub(match_lines.len()) {
257            let matched = match_lines
258                .iter()
259                .enumerate()
260                .all(|(j, expected)| start + j < lines.len() && lines[start + j] == *expected);
261            if matched {
262                return Ok(start);
263            }
264        }
265
266        anyhow::bail!("could not locate hunk context in file")
267    }
268
269    fn resolve_path(
270        input_path: &str,
271        workspace_root: &str,
272        allowed_root: &Path,
273    ) -> anyhow::Result<PathBuf> {
274        if input_path.trim().is_empty() {
275            return Err(anyhow!("file path is required"));
276        }
277        let relative = Path::new(input_path);
278        if relative.is_absolute() {
279            return Err(anyhow!("absolute paths are not allowed"));
280        }
281        if relative
282            .components()
283            .any(|c| matches!(c, Component::ParentDir))
284        {
285            return Err(anyhow!("path traversal is not allowed"));
286        }
287
288        let joined = Path::new(workspace_root).join(relative);
289        let file_name = joined
290            .file_name()
291            .ok_or_else(|| anyhow!("path must target a file"))?
292            .to_os_string();
293        let parent = joined
294            .parent()
295            .ok_or_else(|| anyhow!("path must have a parent directory"))?;
296
297        // For new files, the parent might not exist yet — create it.
298        if !parent.exists() {
299            std::fs::create_dir_all(parent)
300                .with_context(|| format!("failed to create parent directory for {input_path}"))?;
301        }
302
303        let canonical_parent = parent
304            .canonicalize()
305            .with_context(|| format!("unable to resolve path parent: {input_path}"))?;
306        let canonical_allowed_root = allowed_root
307            .canonicalize()
308            .context("unable to resolve allowed root")?;
309        if !canonical_parent.starts_with(&canonical_allowed_root) {
310            return Err(anyhow!("path is outside allowed root"));
311        }
312        Ok(canonical_parent.join(file_name))
313    }
314}
315
316#[async_trait]
317impl Tool for ApplyPatchTool {
318    fn name(&self) -> &'static str {
319        "apply_patch"
320    }
321
322    fn description(&self) -> &'static str {
323        "Apply a unified patch to one or more files. Supports update, add, and delete operations with context-based matching and dry-run mode."
324    }
325
326    fn input_schema(&self) -> Option<serde_json::Value> {
327        Some(serde_json::json!({
328            "type": "object",
329            "properties": {
330                "patch": {
331                    "type": "string",
332                    "description": "The patch text in unified diff format"
333                },
334                "dry_run": {
335                    "type": "boolean",
336                    "description": "If true, show what would change without modifying files"
337                }
338            },
339            "required": ["patch"]
340        }))
341    }
342
343    async fn execute(&self, input: &str, ctx: &ToolContext) -> anyhow::Result<ToolResult> {
344        let request: ApplyPatchInput = serde_json::from_str(input)
345            .context("apply_patch expects JSON: {\"patch\": \"...\", \"dry_run\": false}")?;
346
347        self.validate_patch(&request.patch)?;
348        let patch_files = Self::parse_patch(&request.patch)?;
349
350        let allowed_root = PathBuf::from(&ctx.workspace_root);
351        let mut results = Vec::new();
352
353        for pf in &patch_files {
354            let dest = Self::resolve_path(&pf.path, &ctx.workspace_root, &allowed_root)?;
355
356            // B7: Hard-link guard.
357            if dest.exists() {
358                crate::autonomy::AutonomyPolicy::check_hard_links(&dest.to_string_lossy())?;
359            }
360
361            // B7: Sensitive file detection.
362            if !ctx.allow_sensitive_file_writes
363                && crate::autonomy::is_sensitive_path(&dest.to_string_lossy())
364            {
365                return Err(anyhow!(
366                    "refusing to patch sensitive file: {}",
367                    dest.display()
368                ));
369            }
370
371            match &pf.op {
372                PatchOp::Update(hunks) => {
373                    let content = fs::read_to_string(&dest)
374                        .await
375                        .with_context(|| format!("failed to read file: {}", pf.path))?;
376                    let updated = Self::apply_hunks(&content, hunks)?;
377
378                    if request.dry_run {
379                        results.push(format!("update {} (dry_run)", pf.path));
380                    } else {
381                        fs::write(&dest, &updated)
382                            .await
383                            .with_context(|| format!("failed to write file: {}", pf.path))?;
384                        results.push(format!("updated {}", pf.path));
385                    }
386                }
387                PatchOp::Add(content) => {
388                    if request.dry_run {
389                        results.push(format!(
390                            "add {} ({} bytes, dry_run)",
391                            pf.path,
392                            content.len()
393                        ));
394                    } else {
395                        fs::write(&dest, content)
396                            .await
397                            .with_context(|| format!("failed to create file: {}", pf.path))?;
398                        results.push(format!("added {}", pf.path));
399                    }
400                }
401                PatchOp::Delete => {
402                    if request.dry_run {
403                        results.push(format!("delete {} (dry_run)", pf.path));
404                    } else {
405                        fs::remove_file(&dest)
406                            .await
407                            .with_context(|| format!("failed to delete file: {}", pf.path))?;
408                        results.push(format!("deleted {}", pf.path));
409                    }
410                }
411            }
412        }
413
414        Ok(ToolResult {
415            output: results.join("\n"),
416        })
417    }
418}
419
#[cfg(test)]
mod tests {
    use super::ApplyPatchTool;
    use agentzero_core::{Tool, ToolContext, ToolResult};
    use std::fs;
    use std::path::{Path, PathBuf};
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Keeps temp directory names unique even within the same nanosecond.
    static TEMP_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Creates a fresh, uniquely named scratch directory under the system temp dir.
    fn temp_dir() -> PathBuf {
        let seq = TEMP_COUNTER.fetch_add(1, Ordering::Relaxed);
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock should be after unix epoch")
            .as_nanos();
        let name = format!(
            "agentzero-apply-patch-{}-{nanos}-{seq}",
            std::process::id()
        );
        let dir = std::env::temp_dir().join(name);
        fs::create_dir_all(&dir).expect("temp dir should be created");
        dir
    }

    /// Runs the tool with raw JSON `input` against a workspace rooted at `dir`.
    async fn run(dir: &Path, input: &str) -> anyhow::Result<ToolResult> {
        let tool = ApplyPatchTool;
        let ctx = ToolContext::new(dir.to_string_lossy().to_string());
        tool.execute(input, &ctx).await
    }

    /// Runs the tool, expecting failure, and returns the error message.
    async fn run_err(dir: &Path, input: &str, why: &str) -> String {
        run(dir, input).await.expect_err(why).to_string()
    }

    #[test]
    fn validate_patch_accepts_basic_envelope_success_path() {
        let patch = "*** Begin Patch\n*** Update File: test.txt\n@@\n-old\n+new\n*** End Patch\n";
        ApplyPatchTool
            .validate_patch(patch)
            .expect("well-formed patch should validate");
    }

    #[test]
    fn validate_patch_rejects_missing_begin_marker_negative_path() {
        let message = ApplyPatchTool
            .validate_patch("*** Update File: test.txt\n*** End Patch\n")
            .expect_err("missing begin marker should fail")
            .to_string();
        assert!(message.contains("patch must start with"));
    }

    #[tokio::test]
    async fn apply_patch_single_file_single_hunk() {
        let dir = temp_dir();
        let target = dir.join("hello.txt");
        fs::write(&target, "line1\nline2\nline3\n").unwrap();

        let input = r#"{"patch": "*** Begin Patch\n*** Update File: hello.txt\n@@\n line1\n-line2\n+line2_modified\n line3\n*** End Patch"}"#;
        let outcome = run(&dir, input).await.expect("patch should apply");
        assert!(outcome.output.contains("updated hello.txt"));

        let written = fs::read_to_string(&target).unwrap();
        assert!(written.contains("line2_modified"));
        assert!(!written.contains("\nline2\n"));
        fs::remove_dir_all(dir).ok();
    }

    #[tokio::test]
    async fn apply_patch_dry_run_does_not_modify() {
        let dir = temp_dir();
        let target = dir.join("hello.txt");
        fs::write(&target, "line1\nline2\n").unwrap();

        let input = r#"{"patch": "*** Begin Patch\n*** Update File: hello.txt\n@@\n-line2\n+changed\n*** End Patch", "dry_run": true}"#;
        let outcome = run(&dir, input).await.expect("dry_run should succeed");
        assert!(outcome.output.contains("dry_run"));

        let untouched = fs::read_to_string(&target).unwrap();
        assert!(untouched.contains("line2"));
        fs::remove_dir_all(dir).ok();
    }

    #[tokio::test]
    async fn apply_patch_add_file() {
        let dir = temp_dir();

        let input = r#"{"patch": "*** Begin Patch\n*** Add File: new.txt\n+hello world\n+second line\n*** End Patch"}"#;
        let outcome = run(&dir, input).await.expect("add file should succeed");
        assert!(outcome.output.contains("added new.txt"));

        let created = fs::read_to_string(dir.join("new.txt")).unwrap();
        assert!(created.contains("hello world"));
        fs::remove_dir_all(dir).ok();
    }

    #[tokio::test]
    async fn apply_patch_delete_file() {
        let dir = temp_dir();
        let target = dir.join("doomed.txt");
        fs::write(&target, "goodbye").unwrap();

        let input = r#"{"patch": "*** Begin Patch\n*** Delete File: doomed.txt\n*** End Patch"}"#;
        let outcome = run(&dir, input).await.expect("delete should succeed");
        assert!(outcome.output.contains("deleted doomed.txt"));
        assert!(!target.exists());
        fs::remove_dir_all(dir).ok();
    }

    #[tokio::test]
    async fn apply_patch_rejects_path_traversal_negative_path() {
        let dir = temp_dir();

        let input =
            r#"{"patch": "*** Begin Patch\n*** Add File: ../escape.txt\n+evil\n*** End Patch"}"#;
        let message = run_err(&dir, input, "path traversal should be denied").await;
        assert!(message.contains("path traversal"));
        fs::remove_dir_all(dir).ok();
    }

    #[tokio::test]
    async fn apply_patch_rejects_missing_context_negative_path() {
        let dir = temp_dir();
        fs::write(dir.join("hello.txt"), "aaa\nbbb\nccc\n").unwrap();

        let input = r#"{"patch": "*** Begin Patch\n*** Update File: hello.txt\n@@\n nonexistent_context\n-bbb\n+replaced\n*** End Patch"}"#;
        let message = run_err(&dir, input, "missing context should fail").await;
        assert!(message.contains("could not be matched"));
        fs::remove_dir_all(dir).ok();
    }

    #[tokio::test]
    async fn apply_patch_rejects_sensitive_file_negative_path() {
        let dir = temp_dir();

        let input = r#"{"patch": "*** Begin Patch\n*** Add File: .env\n+SECRET=x\n*** End Patch"}"#;
        let message = run_err(&dir, input, "sensitive file should be blocked").await;
        assert!(message.contains("refusing to patch sensitive file"));
        fs::remove_dir_all(dir).ok();
    }

    #[tokio::test]
    async fn apply_patch_multi_file() {
        let dir = temp_dir();
        let first = dir.join("a.txt");
        let second = dir.join("b.txt");
        fs::write(&first, "alpha\nbeta\n").unwrap();
        fs::write(&second, "one\ntwo\n").unwrap();

        let input = r#"{"patch": "*** Begin Patch\n*** Update File: a.txt\n@@\n-beta\n+BETA\n*** Update File: b.txt\n@@\n-two\n+TWO\n*** End Patch"}"#;
        let outcome = run(&dir, input)
            .await
            .expect("multi-file patch should apply");
        assert!(outcome.output.contains("updated a.txt"));
        assert!(outcome.output.contains("updated b.txt"));

        assert!(fs::read_to_string(&first).unwrap().contains("BETA"));
        assert!(fs::read_to_string(&second).unwrap().contains("TWO"));
        fs::remove_dir_all(dir).ok();
    }
}