1use crate::error::ToolError;
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
/// Input for a patch operation: the file to patch and the patch text to apply to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatchInput {
    /// Path of the target file; it is read, patched in memory, and written back.
    pub file_path: String,
    /// Patch text in unified diff format (hunks introduced by `@@ -a,b +c,d @@` headers).
    pub patch_content: String,
}
17
/// Result of a patch operation, with per-hunk failure details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatchOutput {
    /// `true` only when every hunk applied and the patched file was written.
    pub success: bool,
    /// Number of hunks that applied cleanly.
    pub applied_hunks: usize,
    /// Number of hunks that could not be applied.
    pub failed_hunks: usize,
    /// One entry per failed hunk; empty on success.
    pub failed_hunk_details: Vec<FailedHunkInfo>,
}
30
/// Describes why a single hunk could not be applied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FailedHunkInfo {
    /// 1-based position of the hunk within the patch.
    pub hunk_number: usize,
    /// Line number associated with the failure (the hunk start or the line that differed).
    pub line_number: usize,
    /// Short description of the failure, e.g. a content mismatch or a file that is too short.
    pub error: String,
    /// Expected-versus-found text for mismatches; `None` when there is nothing to compare.
    pub context: Option<String>,
}
43
/// A single parsed hunk of a unified diff.
#[derive(Debug, Clone)]
struct Hunk {
    /// Start line taken from the original-file range of the hunk header (`-start[,count]`).
    pub orig_start: usize,
    /// Hunk body lines as they appear in the patch, leading marker character included.
    pub lines: Vec<String>,
}
52
/// Stateless tool that applies unified-diff patches to files on disk.
pub struct PatchTool;
55
56impl PatchTool {
    /// Creates a `PatchTool`. The type is a unit struct, so the value carries no state.
    pub fn new() -> Self {
        Self
    }
61
62 fn parse_patch(patch_content: &str) -> Result<Vec<Hunk>, ToolError> {
64 let mut hunks = Vec::new();
65 let mut lines = patch_content.lines().peekable();
66
67 while let Some(line) = lines.peek() {
69 if line.starts_with("---") || line.starts_with("+++") {
70 lines.next();
71 } else if line.starts_with("@@") {
72 break;
73 } else if line.is_empty() || line.starts_with("diff ") || line.starts_with("index ") {
74 lines.next();
75 } else {
76 lines.next();
77 }
78 }
79
80 while let Some(line) = lines.next() {
82 if !line.starts_with("@@") {
83 continue;
84 }
85
86 let hunk_header = line
88 .trim_start_matches("@@")
89 .trim_end_matches("@@")
90 .trim();
91
92 let parts: Vec<&str> = hunk_header.split_whitespace().collect();
93 if parts.len() < 2 {
94 return Err(ToolError::new("INVALID_PATCH", "Invalid hunk header format")
95 .with_details(format!("Hunk header: {}", line))
96 .with_suggestion("Ensure patch is in unified diff format"));
97 }
98
99 let (orig_start, _orig_count) = Self::parse_range(parts[0])?;
100 let (_new_start, _new_count) = Self::parse_range(parts[1])?;
101
102 let mut hunk_lines = Vec::new();
103
104 while let Some(hunk_line) = lines.peek() {
106 if hunk_line.starts_with("@@") {
107 break;
108 }
109 if hunk_line.starts_with("\\") {
110 lines.next();
112 continue;
113 }
114 if hunk_line.is_empty() || hunk_line.starts_with("-") || hunk_line.starts_with("+")
115 || hunk_line.starts_with(" ")
116 {
117 hunk_lines.push(lines.next().unwrap().to_string());
118 } else {
119 break;
120 }
121 }
122
123 hunks.push(Hunk {
124 orig_start,
125 lines: hunk_lines,
126 });
127 }
128
129 if hunks.is_empty() {
130 return Err(ToolError::new("INVALID_PATCH", "No hunks found in patch")
131 .with_suggestion("Ensure patch contains at least one hunk"));
132 }
133
134 Ok(hunks)
135 }
136
137 fn parse_range(range_spec: &str) -> Result<(usize, usize), ToolError> {
139 let range_spec = range_spec.trim_start_matches('-').trim_start_matches('+');
140 let parts: Vec<&str> = range_spec.split(',').collect();
141
142 match parts.len() {
143 1 => {
144 let start = parts[0]
145 .parse::<usize>()
146 .map_err(|_| ToolError::new("INVALID_PATCH", "Invalid line number in hunk header"))?;
147 Ok((start, 1))
148 }
149 2 => {
150 let start = parts[0]
151 .parse::<usize>()
152 .map_err(|_| ToolError::new("INVALID_PATCH", "Invalid line number in hunk header"))?;
153 let count = parts[1]
154 .parse::<usize>()
155 .map_err(|_| ToolError::new("INVALID_PATCH", "Invalid line count in hunk header"))?;
156 Ok((start, count))
157 }
158 _ => Err(ToolError::new("INVALID_PATCH", "Invalid range specification")
159 .with_details(format!("Range: {}", range_spec))),
160 }
161 }
162
163 fn apply_hunk(
165 file_lines: &mut Vec<String>,
166 hunk: &Hunk,
167 hunk_number: usize,
168 ) -> Result<(), FailedHunkInfo> {
169 let mut file_idx = if hunk.orig_start > 0 {
171 hunk.orig_start - 1
172 } else {
173 0
174 };
175
176 let mut hunk_idx = 0;
177 let mut lines_to_add = Vec::new();
178
179 let mut temp_file_idx = file_idx;
181 let mut temp_hunk_idx = 0;
182
183 while temp_hunk_idx < hunk.lines.len() {
184 let hunk_line = &hunk.lines[temp_hunk_idx];
185
186 if hunk_line.starts_with('-') {
187 let expected = &hunk_line[1..];
189 if temp_file_idx >= file_lines.len() {
190 return Err(FailedHunkInfo {
191 hunk_number,
192 line_number: hunk.orig_start,
193 error: "File is too short for this hunk".to_string(),
194 context: None,
195 });
196 }
197
198 if file_lines[temp_file_idx] != expected {
199 let context = format!(
200 "Expected: '{}', Found: '{}'",
201 expected, file_lines[temp_file_idx]
202 );
203 return Err(FailedHunkInfo {
204 hunk_number,
205 line_number: hunk.orig_start + temp_file_idx - file_idx,
206 error: "Line content mismatch".to_string(),
207 context: Some(context),
208 });
209 }
210 temp_file_idx += 1;
211 } else if hunk_line.starts_with('+') {
212 lines_to_add.push(hunk_line[1..].to_string());
214 } else if hunk_line.starts_with(' ') {
215 let expected = &hunk_line[1..];
217 if temp_file_idx >= file_lines.len() {
218 return Err(FailedHunkInfo {
219 hunk_number,
220 line_number: hunk.orig_start,
221 error: "File is too short for this hunk".to_string(),
222 context: None,
223 });
224 }
225
226 if file_lines[temp_file_idx] != expected {
227 let context = format!(
228 "Expected: '{}', Found: '{}'",
229 expected, file_lines[temp_file_idx]
230 );
231 return Err(FailedHunkInfo {
232 hunk_number,
233 line_number: hunk.orig_start + temp_file_idx - file_idx,
234 error: "Context line mismatch".to_string(),
235 context: Some(context),
236 });
237 }
238 temp_file_idx += 1;
239 }
240 temp_hunk_idx += 1;
241 }
242
243 while hunk_idx < hunk.lines.len() {
245 let hunk_line = &hunk.lines[hunk_idx];
246
247 if hunk_line.starts_with('-') {
248 if file_idx < file_lines.len() {
250 file_lines.remove(file_idx);
251 }
252 } else if hunk_line.starts_with('+') {
253 file_lines.insert(file_idx, hunk_line[1..].to_string());
255 file_idx += 1;
256 } else if hunk_line.starts_with(' ') {
257 file_idx += 1;
259 }
260 hunk_idx += 1;
261 }
262
263 Ok(())
264 }
265
266 pub async fn apply_patch_with_timeout(input: &PatchInput) -> Result<PatchOutput, ToolError> {
268 let timeout_duration = std::time::Duration::from_secs(1);
269
270 match tokio::time::timeout(timeout_duration, async {
271 Self::apply_patch_internal(input)
272 }).await {
273 Ok(result) => result,
274 Err(_) => {
275 Err(ToolError::new("TIMEOUT", "Patch operation exceeded 1 second timeout")
276 .with_details(format!("File: {}", input.file_path))
277 .with_suggestion("Try applying the patch again or check file size"))
278 }
279 }
280 }
281
    /// Applies the patch described by `input` synchronously, without a timeout.
    ///
    /// Delegates to `apply_patch_internal` and returns its result unchanged.
    pub fn apply_patch(input: &PatchInput) -> Result<PatchOutput, ToolError> {
        Self::apply_patch_internal(input)
    }
286
287 fn apply_patch_internal(input: &PatchInput) -> Result<PatchOutput, ToolError> {
289 let hunks = Self::parse_patch(&input.patch_content)?;
291
292 let file_path = Path::new(&input.file_path);
294 let file_content = std::fs::read_to_string(file_path)
295 .map_err(|e| {
296 if e.kind() == std::io::ErrorKind::NotFound {
297 ToolError::new("FILE_NOT_FOUND", format!("File not found: {}", input.file_path))
298 .with_suggestion("Ensure the file path is correct")
299 } else {
300 ToolError::from(e)
301 }
302 })?;
303
304 let mut file_lines: Vec<String> = file_content.lines().map(|s| s.to_string()).collect();
305
306 let mut applied_hunks = 0;
307 let mut failed_hunks = 0;
308 let mut failed_hunk_details = Vec::new();
309
310 for (idx, hunk) in hunks.iter().enumerate() {
312 match Self::apply_hunk(&mut file_lines, hunk, idx + 1) {
313 Ok(()) => {
314 applied_hunks += 1;
315 }
316 Err(failed_info) => {
317 failed_hunks += 1;
318 failed_hunk_details.push(failed_info);
319 }
320 }
321 }
322
323 if failed_hunks > 0 {
325 return Ok(PatchOutput {
326 success: false,
327 applied_hunks,
328 failed_hunks,
329 failed_hunk_details,
330 });
331 }
332
333 let patched_content = file_lines.join("\n");
335 std::fs::write(file_path, patched_content).map_err(|e| {
336 if e.kind() == std::io::ErrorKind::PermissionDenied {
337 ToolError::new("PERMISSION_DENIED", "Permission denied writing to file")
338 .with_suggestion("Check file permissions")
339 } else {
340 ToolError::from(e)
341 }
342 })?;
343
344 Ok(PatchOutput {
345 success: true,
346 applied_hunks,
347 failed_hunks: 0,
348 failed_hunk_details: Vec::new(),
349 })
350 }
351}
352
353impl Default for PatchTool {
354 fn default() -> Self {
355 Self::new()
356 }
357}
358
359pub mod provider {
361 use super::*;
362 use crate::provider::Provider;
363 use async_trait::async_trait;
364 use std::sync::Arc;
365 use tracing::{debug, warn};
366
367 pub struct BuiltinPatchProvider;
369
370 #[async_trait]
371 impl Provider for BuiltinPatchProvider {
372 async fn execute(&self, input: &str) -> Result<String, ToolError> {
373 debug!("Executing patch with built-in provider");
374
375 let patch_input: PatchInput = serde_json::from_str(input).map_err(|e| {
377 ToolError::new("INVALID_INPUT", "Failed to parse patch input")
378 .with_details(e.to_string())
379 .with_suggestion("Ensure input is valid JSON with file_path and patch_content")
380 })?;
381
382 let output = PatchTool::apply_patch(&patch_input)?;
384
385 serde_json::to_string(&output).map_err(|e| {
387 ToolError::new("SERIALIZATION_ERROR", "Failed to serialize patch output")
388 .with_details(e.to_string())
389 })
390 }
391 }
392
393 pub struct McpPatchProvider {
395 mcp_provider: Arc<dyn Provider>,
396 }
397
398 impl McpPatchProvider {
399 pub fn new(mcp_provider: Arc<dyn Provider>) -> Self {
401 Self { mcp_provider }
402 }
403 }
404
405 #[async_trait]
406 impl Provider for McpPatchProvider {
407 async fn execute(&self, input: &str) -> Result<String, ToolError> {
408 debug!("Executing patch with MCP provider");
409
410 match self.mcp_provider.execute(input).await {
411 Ok(result) => {
412 debug!("MCP patch provider succeeded");
413 Ok(result)
414 }
415 Err(e) => {
416 warn!("MCP patch provider failed, would fall back to built-in: {}", e);
417 Err(e)
418 }
419 }
420 }
421 }
422}
423
/// Unit tests for range/patch parsing, patch application, and the built-in provider.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;
    use std::io::Write;

    /// Patch shared by most tests: rewrites "line 2" inside a three-line file.
    const SAMPLE_PATCH: &str = concat!(
        "--- a/test.txt\n",
        "+++ b/test.txt\n",
        "@@ -1,3 +1,3 @@\n",
        " line 1\n",
        "-line 2\n",
        "+line 2 modified\n",
        " line 3",
    );

    /// Creates a temporary file holding `lines`, each terminated by a newline.
    fn temp_file_with_lines(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// Builds a `PatchInput` that applies `SAMPLE_PATCH` to the file at `path`.
    fn sample_input(path: String) -> PatchInput {
        PatchInput {
            file_path: path,
            patch_content: SAMPLE_PATCH.to_string(),
        }
    }

    /// Returns the path of `file` as an owned string.
    fn path_of(file: &NamedTempFile) -> String {
        file.path().to_string_lossy().to_string()
    }

    #[test]
    fn test_parse_range_single_line() {
        let parsed = PatchTool::parse_range("-10").unwrap();
        assert_eq!(parsed.0, 10);
        assert_eq!(parsed.1, 1);
    }

    #[test]
    fn test_parse_range_multiple_lines() {
        let parsed = PatchTool::parse_range("-10,5").unwrap();
        assert_eq!(parsed.0, 10);
        assert_eq!(parsed.1, 5);
    }

    #[test]
    fn test_parse_simple_patch() {
        let hunks = PatchTool::parse_patch(SAMPLE_PATCH).unwrap();
        assert_eq!(hunks.len(), 1);

        let first = &hunks[0];
        assert_eq!(first.orig_start, 1);
        assert!(!first.lines.is_empty());
    }

    #[test]
    fn test_apply_simple_patch() {
        let file = temp_file_with_lines(&["line 1", "line 2", "line 3"]);
        let input = sample_input(path_of(&file));

        let output = PatchTool::apply_patch(&input).unwrap();
        assert!(output.success);
        assert_eq!(output.applied_hunks, 1);
        assert_eq!(output.failed_hunks, 0);

        let content = std::fs::read_to_string(file.path()).unwrap();
        assert!(content.contains("line 2 modified"));
    }

    #[test]
    fn test_patch_conflict_detection() {
        // The second line differs from what the patch expects to remove.
        let file = temp_file_with_lines(&["line 1", "different line", "line 3"]);
        let input = sample_input(path_of(&file));

        let output = PatchTool::apply_patch(&input).unwrap();
        assert!(!output.success);
        assert_eq!(output.failed_hunks, 1);
        assert!(!output.failed_hunk_details.is_empty());
    }

    #[test]
    fn test_patch_file_not_found() {
        let input = sample_input("/nonexistent/file.txt".to_string());

        let err = PatchTool::apply_patch(&input).unwrap_err();
        assert_eq!(err.code, "FILE_NOT_FOUND");
    }

    #[test]
    fn test_invalid_patch_format() {
        let err = PatchTool::parse_patch("invalid patch content").unwrap_err();
        assert_eq!(err.code, "INVALID_PATCH");
    }

    #[tokio::test]
    async fn test_builtin_provider() {
        use crate::patch::provider::BuiltinPatchProvider;
        use crate::provider::Provider;

        let file = temp_file_with_lines(&["line 1", "line 2", "line 3"]);
        let input_json = serde_json::to_string(&sample_input(path_of(&file))).unwrap();

        let provider = BuiltinPatchProvider;
        let result = provider.execute(&input_json).await.unwrap();

        let output: PatchOutput = serde_json::from_str(&result).unwrap();
        assert!(output.success);
        assert_eq!(output.applied_hunks, 1);
    }

    #[tokio::test]
    async fn test_patch_timeout_enforcement() {
        let file = temp_file_with_lines(&["line 1", "line 2", "line 3"]);
        let input = sample_input(path_of(&file));

        let result = PatchTool::apply_patch_with_timeout(&input).await;
        assert!(result.is_ok());
        assert!(result.unwrap().success);
    }
}