ricecoder_tools/
patch.rs

1//! Patch tool for applying unified diff patches
2//!
3//! Provides functionality to parse and apply unified diff patches with conflict detection.
4
5use crate::error::ToolError;
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8
9/// Input for patch operations
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct PatchInput {
12    /// Path to the file to patch
13    pub file_path: String,
14    /// Unified diff patch content
15    pub patch_content: String,
16}
17
18/// Output from patch operations
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct PatchOutput {
21    /// Whether the patch was successfully applied
22    pub success: bool,
23    /// Number of hunks successfully applied
24    pub applied_hunks: usize,
25    /// Number of hunks that failed to apply
26    pub failed_hunks: usize,
27    /// Details about failed hunks
28    pub failed_hunk_details: Vec<FailedHunkInfo>,
29}
30
31/// Information about a failed hunk
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct FailedHunkInfo {
34    /// Hunk number (1-indexed)
35    pub hunk_number: usize,
36    /// Starting line number in the original file
37    pub line_number: usize,
38    /// Error message
39    pub error: String,
40    /// Context lines around the failure
41    pub context: Option<String>,
42}
43
/// One `@@ ... @@` section of a unified diff, as read from the patch text.
#[derive(Clone, Debug)]
struct Hunk {
    /// Start line of the original-file range (the `-N` part of the hunk header).
    orig_start: usize,
    /// Body lines of the hunk, each still carrying its `' '`, `'-'` or `'+'` marker;
    /// a line may also be empty.
    lines: Vec<String>,
}
52
/// Applies unified diff patches to files on disk.
///
/// The tool carries no state; its operations are associated functions.
#[derive(Debug, Clone, Copy)]
pub struct PatchTool;
55
56impl PatchTool {
57    /// Create a new patch tool
58    pub fn new() -> Self {
59        Self
60    }
61
62    /// Parse a unified diff patch
63    fn parse_patch(patch_content: &str) -> Result<Vec<Hunk>, ToolError> {
64        let mut hunks = Vec::new();
65        let mut lines = patch_content.lines().peekable();
66
67        // Skip file headers (--- and +++ lines)
68        while let Some(line) = lines.peek() {
69            if line.starts_with("---") || line.starts_with("+++") {
70                lines.next();
71            } else if line.starts_with("@@") {
72                break;
73            } else if line.is_empty() || line.starts_with("diff ") || line.starts_with("index ") {
74                lines.next();
75            } else {
76                lines.next();
77            }
78        }
79
80        // Parse hunks
81        while let Some(line) = lines.next() {
82            if !line.starts_with("@@") {
83                continue;
84            }
85
86            // Parse hunk header: @@ -orig_start,orig_count +new_start,new_count @@
87            let hunk_header = line
88                .trim_start_matches("@@")
89                .trim_end_matches("@@")
90                .trim();
91
92            let parts: Vec<&str> = hunk_header.split_whitespace().collect();
93            if parts.len() < 2 {
94                return Err(ToolError::new("INVALID_PATCH", "Invalid hunk header format")
95                    .with_details(format!("Hunk header: {}", line))
96                    .with_suggestion("Ensure patch is in unified diff format"));
97            }
98
99            let (orig_start, _orig_count) = Self::parse_range(parts[0])?;
100            let (_new_start, _new_count) = Self::parse_range(parts[1])?;
101
102            let mut hunk_lines = Vec::new();
103
104            // Read hunk lines
105            while let Some(hunk_line) = lines.peek() {
106                if hunk_line.starts_with("@@") {
107                    break;
108                }
109                if hunk_line.starts_with("\\") {
110                    // Skip "\ No newline at end of file" markers
111                    lines.next();
112                    continue;
113                }
114                if hunk_line.is_empty() || hunk_line.starts_with("-") || hunk_line.starts_with("+")
115                    || hunk_line.starts_with(" ")
116                {
117                    hunk_lines.push(lines.next().unwrap().to_string());
118                } else {
119                    break;
120                }
121            }
122
123            hunks.push(Hunk {
124                orig_start,
125                lines: hunk_lines,
126            });
127        }
128
129        if hunks.is_empty() {
130            return Err(ToolError::new("INVALID_PATCH", "No hunks found in patch")
131                .with_suggestion("Ensure patch contains at least one hunk"));
132        }
133
134        Ok(hunks)
135    }
136
137    /// Parse a range specification (e.g., "-10,5" or "+20,3")
138    fn parse_range(range_spec: &str) -> Result<(usize, usize), ToolError> {
139        let range_spec = range_spec.trim_start_matches('-').trim_start_matches('+');
140        let parts: Vec<&str> = range_spec.split(',').collect();
141
142        match parts.len() {
143            1 => {
144                let start = parts[0]
145                    .parse::<usize>()
146                    .map_err(|_| ToolError::new("INVALID_PATCH", "Invalid line number in hunk header"))?;
147                Ok((start, 1))
148            }
149            2 => {
150                let start = parts[0]
151                    .parse::<usize>()
152                    .map_err(|_| ToolError::new("INVALID_PATCH", "Invalid line number in hunk header"))?;
153                let count = parts[1]
154                    .parse::<usize>()
155                    .map_err(|_| ToolError::new("INVALID_PATCH", "Invalid line count in hunk header"))?;
156                Ok((start, count))
157            }
158            _ => Err(ToolError::new("INVALID_PATCH", "Invalid range specification")
159                .with_details(format!("Range: {}", range_spec))),
160        }
161    }
162
163    /// Apply a single hunk to file lines
164    fn apply_hunk(
165        file_lines: &mut Vec<String>,
166        hunk: &Hunk,
167        hunk_number: usize,
168    ) -> Result<(), FailedHunkInfo> {
169        // Convert to 0-indexed
170        let mut file_idx = if hunk.orig_start > 0 {
171            hunk.orig_start - 1
172        } else {
173            0
174        };
175
176        let mut hunk_idx = 0;
177        let mut lines_to_add = Vec::new();
178
179        // First pass: validate the hunk matches the file
180        let mut temp_file_idx = file_idx;
181        let mut temp_hunk_idx = 0;
182
183        while temp_hunk_idx < hunk.lines.len() {
184            let hunk_line = &hunk.lines[temp_hunk_idx];
185
186            if hunk_line.starts_with('-') {
187                // Line should be removed
188                let expected = &hunk_line[1..];
189                if temp_file_idx >= file_lines.len() {
190                    return Err(FailedHunkInfo {
191                        hunk_number,
192                        line_number: hunk.orig_start,
193                        error: "File is too short for this hunk".to_string(),
194                        context: None,
195                    });
196                }
197
198                if file_lines[temp_file_idx] != expected {
199                    let context = format!(
200                        "Expected: '{}', Found: '{}'",
201                        expected, file_lines[temp_file_idx]
202                    );
203                    return Err(FailedHunkInfo {
204                        hunk_number,
205                        line_number: hunk.orig_start + temp_file_idx - file_idx,
206                        error: "Line content mismatch".to_string(),
207                        context: Some(context),
208                    });
209                }
210                temp_file_idx += 1;
211            } else if hunk_line.starts_with('+') {
212                // Line should be added
213                lines_to_add.push(hunk_line[1..].to_string());
214            } else if hunk_line.starts_with(' ') {
215                // Context line should match
216                let expected = &hunk_line[1..];
217                if temp_file_idx >= file_lines.len() {
218                    return Err(FailedHunkInfo {
219                        hunk_number,
220                        line_number: hunk.orig_start,
221                        error: "File is too short for this hunk".to_string(),
222                        context: None,
223                    });
224                }
225
226                if file_lines[temp_file_idx] != expected {
227                    let context = format!(
228                        "Expected: '{}', Found: '{}'",
229                        expected, file_lines[temp_file_idx]
230                    );
231                    return Err(FailedHunkInfo {
232                        hunk_number,
233                        line_number: hunk.orig_start + temp_file_idx - file_idx,
234                        error: "Context line mismatch".to_string(),
235                        context: Some(context),
236                    });
237                }
238                temp_file_idx += 1;
239            }
240            temp_hunk_idx += 1;
241        }
242
243        // Second pass: apply the hunk
244        while hunk_idx < hunk.lines.len() {
245            let hunk_line = &hunk.lines[hunk_idx];
246
247            if hunk_line.starts_with('-') {
248                // Remove line
249                if file_idx < file_lines.len() {
250                    file_lines.remove(file_idx);
251                }
252            } else if hunk_line.starts_with('+') {
253                // Add line
254                file_lines.insert(file_idx, hunk_line[1..].to_string());
255                file_idx += 1;
256            } else if hunk_line.starts_with(' ') {
257                // Context line
258                file_idx += 1;
259            }
260            hunk_idx += 1;
261        }
262
263        Ok(())
264    }
265
266    /// Apply a patch to a file with timeout enforcement (1 second)
267    pub async fn apply_patch_with_timeout(input: &PatchInput) -> Result<PatchOutput, ToolError> {
268        let timeout_duration = std::time::Duration::from_secs(1);
269        
270        match tokio::time::timeout(timeout_duration, async {
271            Self::apply_patch_internal(input)
272        }).await {
273            Ok(result) => result,
274            Err(_) => {
275                Err(ToolError::new("TIMEOUT", "Patch operation exceeded 1 second timeout")
276                    .with_details(format!("File: {}", input.file_path))
277                    .with_suggestion("Try applying the patch again or check file size"))
278            }
279        }
280    }
281
282    /// Apply a patch to a file (synchronous version)
283    pub fn apply_patch(input: &PatchInput) -> Result<PatchOutput, ToolError> {
284        Self::apply_patch_internal(input)
285    }
286
287    /// Internal patch application logic
288    fn apply_patch_internal(input: &PatchInput) -> Result<PatchOutput, ToolError> {
289        // Parse the patch
290        let hunks = Self::parse_patch(&input.patch_content)?;
291
292        // Read the file
293        let file_path = Path::new(&input.file_path);
294        let file_content = std::fs::read_to_string(file_path)
295            .map_err(|e| {
296                if e.kind() == std::io::ErrorKind::NotFound {
297                    ToolError::new("FILE_NOT_FOUND", format!("File not found: {}", input.file_path))
298                        .with_suggestion("Ensure the file path is correct")
299                } else {
300                    ToolError::from(e)
301                }
302            })?;
303
304        let mut file_lines: Vec<String> = file_content.lines().map(|s| s.to_string()).collect();
305
306        let mut applied_hunks = 0;
307        let mut failed_hunks = 0;
308        let mut failed_hunk_details = Vec::new();
309
310        // Apply each hunk
311        for (idx, hunk) in hunks.iter().enumerate() {
312            match Self::apply_hunk(&mut file_lines, hunk, idx + 1) {
313                Ok(()) => {
314                    applied_hunks += 1;
315                }
316                Err(failed_info) => {
317                    failed_hunks += 1;
318                    failed_hunk_details.push(failed_info);
319                }
320            }
321        }
322
323        // If any hunks failed, don't write the file
324        if failed_hunks > 0 {
325            return Ok(PatchOutput {
326                success: false,
327                applied_hunks,
328                failed_hunks,
329                failed_hunk_details,
330            });
331        }
332
333        // Write the patched file
334        let patched_content = file_lines.join("\n");
335        std::fs::write(file_path, patched_content).map_err(|e| {
336            if e.kind() == std::io::ErrorKind::PermissionDenied {
337                ToolError::new("PERMISSION_DENIED", "Permission denied writing to file")
338                    .with_suggestion("Check file permissions")
339            } else {
340                ToolError::from(e)
341            }
342        })?;
343
344        Ok(PatchOutput {
345            success: true,
346            applied_hunks,
347            failed_hunks: 0,
348            failed_hunk_details: Vec::new(),
349        })
350    }
351}
352
353impl Default for PatchTool {
354    fn default() -> Self {
355        Self::new()
356    }
357}
358
359/// Provider implementation for patch tool
360pub mod provider {
361    use super::*;
362    use crate::provider::Provider;
363    use async_trait::async_trait;
364    use std::sync::Arc;
365    use tracing::{debug, warn};
366
367    /// Built-in patch provider
368    pub struct BuiltinPatchProvider;
369
370    #[async_trait]
371    impl Provider for BuiltinPatchProvider {
372        async fn execute(&self, input: &str) -> Result<String, ToolError> {
373            debug!("Executing patch with built-in provider");
374
375            // Parse input as JSON
376            let patch_input: PatchInput = serde_json::from_str(input).map_err(|e| {
377                ToolError::new("INVALID_INPUT", "Failed to parse patch input")
378                    .with_details(e.to_string())
379                    .with_suggestion("Ensure input is valid JSON with file_path and patch_content")
380            })?;
381
382            // Apply the patch
383            let output = PatchTool::apply_patch(&patch_input)?;
384
385            // Return output as JSON
386            serde_json::to_string(&output).map_err(|e| {
387                ToolError::new("SERIALIZATION_ERROR", "Failed to serialize patch output")
388                    .with_details(e.to_string())
389            })
390        }
391    }
392
393    /// MCP patch provider wrapper
394    pub struct McpPatchProvider {
395        mcp_provider: Arc<dyn Provider>,
396    }
397
398    impl McpPatchProvider {
399        /// Create a new MCP patch provider
400        pub fn new(mcp_provider: Arc<dyn Provider>) -> Self {
401            Self { mcp_provider }
402        }
403    }
404
405    #[async_trait]
406    impl Provider for McpPatchProvider {
407        async fn execute(&self, input: &str) -> Result<String, ToolError> {
408            debug!("Executing patch with MCP provider");
409
410            match self.mcp_provider.execute(input).await {
411                Ok(result) => {
412                    debug!("MCP patch provider succeeded");
413                    Ok(result)
414                }
415                Err(e) => {
416                    warn!("MCP patch provider failed, would fall back to built-in: {}", e);
417                    Err(e)
418                }
419            }
420        }
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Patch shared by most tests: replaces "line 2" with "line 2 modified" in a
    /// three-line file.
    const SAMPLE_PATCH: &str = r#"--- a/test.txt
+++ b/test.txt
@@ -1,3 +1,3 @@
 line 1
-line 2
+line 2 modified
 line 3"#;

    /// Create a temporary file containing `lines`, each terminated by a newline.
    fn temp_file_with(lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::new().unwrap();
        for line in lines {
            writeln!(file, "{}", line).unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// A `PatchInput` that applies `SAMPLE_PATCH` to `file`.
    fn sample_input(file: &NamedTempFile) -> PatchInput {
        PatchInput {
            file_path: file.path().to_string_lossy().to_string(),
            patch_content: SAMPLE_PATCH.to_string(),
        }
    }

    #[test]
    fn test_parse_range_single_line() {
        assert_eq!(PatchTool::parse_range("-10").unwrap(), (10, 1));
    }

    #[test]
    fn test_parse_range_multiple_lines() {
        assert_eq!(PatchTool::parse_range("-10,5").unwrap(), (10, 5));
    }

    #[test]
    fn test_parse_simple_patch() {
        let hunks = PatchTool::parse_patch(SAMPLE_PATCH).unwrap();
        assert_eq!(hunks.len(), 1);
        let first = &hunks[0];
        assert_eq!(first.orig_start, 1);
        assert!(!first.lines.is_empty());
    }

    #[test]
    fn test_apply_simple_patch() {
        let file = temp_file_with(&["line 1", "line 2", "line 3"]);

        let output = PatchTool::apply_patch(&sample_input(&file)).unwrap();
        assert!(output.success);
        assert_eq!(output.applied_hunks, 1);
        assert_eq!(output.failed_hunks, 0);

        // The change must have reached the file on disk.
        let on_disk = std::fs::read_to_string(file.path()).unwrap();
        assert!(on_disk.contains("line 2 modified"));
    }

    #[test]
    fn test_patch_conflict_detection() {
        let file = temp_file_with(&["line 1", "different line", "line 3"]);

        let output = PatchTool::apply_patch(&sample_input(&file)).unwrap();
        assert!(!output.success);
        assert_eq!(output.failed_hunks, 1);
        assert!(!output.failed_hunk_details.is_empty());
    }

    #[test]
    fn test_patch_file_not_found() {
        let input = PatchInput {
            file_path: "/nonexistent/file.txt".to_string(),
            patch_content: SAMPLE_PATCH.to_string(),
        };

        match PatchTool::apply_patch(&input) {
            Err(err) => assert_eq!(err.code, "FILE_NOT_FOUND"),
            Ok(_) => panic!("expected FILE_NOT_FOUND for a missing file"),
        }
    }

    #[test]
    fn test_invalid_patch_format() {
        match PatchTool::parse_patch("invalid patch content") {
            Err(err) => assert_eq!(err.code, "INVALID_PATCH"),
            Ok(_) => panic!("expected INVALID_PATCH for text without hunks"),
        }
    }

    #[tokio::test]
    async fn test_builtin_provider() {
        use crate::patch::provider::BuiltinPatchProvider;
        use crate::provider::Provider;

        let file = temp_file_with(&["line 1", "line 2", "line 3"]);
        let request = serde_json::to_string(&sample_input(&file)).unwrap();

        let provider = BuiltinPatchProvider;
        let response = provider.execute(&request).await.unwrap();

        let output: PatchOutput = serde_json::from_str(&response).unwrap();
        assert!(output.success);
        assert_eq!(output.applied_hunks, 1);
    }

    #[tokio::test]
    async fn test_patch_timeout_enforcement() {
        let file = temp_file_with(&["line 1", "line 2", "line 3"]);

        // A tiny patch should finish well within the one-second limit.
        let output = PatchTool::apply_patch_with_timeout(&sample_input(&file))
            .await
            .expect("patch should complete before the timeout");
        assert!(output.success);
    }
}