collet 0.1.1

Relentless agentic coding orchestrator with zero-drop agent loops
Documentation
//! Git-based surgical hunk editing tool.
//!
//! Applies a unified diff patch to a file using `git apply`.
//! This is more robust than exact string matching for multi-hunk edits.

use serde::Deserialize;

use crate::common::{AgentError, Result};

/// Input for the git_patch tool.
///
/// Deserialized with serde from the tool-call arguments; both fields are
/// required (neither carries a default).
#[derive(Debug, Deserialize)]
pub struct GitPatchInput {
    /// Path to the file to patch (relative to working dir).
    ///
    /// A path starting with `/` is used as-is instead of being joined to the
    /// working directory (see `resolve_path`).
    pub path: String,
    /// A unified diff patch to apply.
    /// Must be a valid unified diff format (with --- a/ +++ b/ and @@ hunk headers).
    ///
    /// File headers are rewritten by `normalize_patch` to point at `path`
    /// before the patch is applied.
    pub patch: String,
}

/// Builds the function-calling schema that advertises `git_patch` to the LLM.
///
/// The returned JSON object has `"type": "function"` and a `"function"` entry
/// holding the tool name, its description, and a parameter schema with two
/// required string properties: `path` and `patch`.
pub fn definition() -> serde_json::Value {
    // Schema for each argument, assembled separately for readability.
    let path_schema = serde_json::json!({
        "type": "string",
        "description": "Path to the file to patch"
    });
    let patch_schema = serde_json::json!({
        "type": "string",
        "description": "Unified diff patch content. Must include --- a/path and +++ b/path headers and @@ hunk headers."
    });

    let parameters = serde_json::json!({
        "type": "object",
        "properties": {
            "path": path_schema,
            "patch": patch_schema
        },
        "required": ["path", "patch"]
    });

    serde_json::json!({
        "type": "function",
        "function": {
            "name": "git_patch",
            "description": "Apply a unified diff patch to a file. Use this for multi-hunk edits or when you need to modify several non-adjacent sections of a file at once. The patch must be in unified diff format with proper hunk headers (@@ -line,count +line,count @@).",
            "parameters": parameters
        }
    })
}

/// Execute the git_patch tool.
///
/// Steps:
/// 1. Resolve `input.path` against `working_dir` and require that it exists.
/// 2. Normalize the patch so its file headers name `input.path`.
/// 3. Reject patches that contain no `@@` marker at all.
/// 4. Try `git apply`; if that fails, fall back to the built-in hunk applier.
///
/// # Errors
///
/// Returns `AgentError::InvalidArgument` when the file does not exist, when
/// the patch has no `@@` hunk header, or when both the `git apply` attempt and
/// the manual fallback fail (the message then carries both failure reasons).
pub async fn execute(input: GitPatchInput, working_dir: &str) -> Result<String> {
    // Built with `concat!` rather than `\`-newline continuations: a
    // continuation skips *all* leading whitespace on the next source line,
    // which would strip the single leading space that marks the example's
    // context lines and make the example an invalid diff.
    const INVALID_FORMAT_HELP: &str = concat!(
        "Invalid patch format: missing @@ hunk headers.\n",
        "Expected unified diff format like:\n",
        "--- a/file.rs\n",
        "+++ b/file.rs\n",
        "@@ -1,3 +1,4 @@\n",
        " context line\n",
        "-old line\n",
        "+new line\n",
        " context line",
    );

    let resolved_path = resolve_path(&input.path, working_dir);

    // Verify the file exists before doing any patch work.
    if !std::path::Path::new(&resolved_path).exists() {
        return Err(AgentError::InvalidArgument(format!(
            "File not found: {}",
            input.path
        )));
    }

    // Normalize the patch so its file headers always name the requested path.
    let patch = normalize_patch(&input.patch, &input.path);

    // Cheap format validation: a unified diff needs at least one hunk header.
    if !patch.contains("@@") {
        return Err(AgentError::InvalidArgument(
            INVALID_FORMAT_HELP.to_string(),
        ));
    }

    // Preferred route: let git apply the patch.
    let git_err = match apply_with_git(&patch, working_dir).await {
        Ok(msg) => return Ok(msg),
        Err(err) => err,
    };

    // Fall back to manual patch application, reporting both failures if
    // that does not work either.
    tracing::warn!("git apply failed, trying manual patch: {git_err}");
    match apply_manual(&patch, &resolved_path).await {
        Ok(msg) => Ok(format!("{msg} (manual fallback)")),
        Err(manual_err) => Err(AgentError::InvalidArgument(format!(
            "Patch failed.\n\
                 git apply: {git_err}\n\
                 manual: {manual_err}\n\n\
                 Tip: Use file_edit for single-location changes, \
                 or ensure your patch has correct line numbers and context."
        ))),
    }
}

/// Apply patch using `git apply`.
///
/// Spawns `git apply --verbose -` in `working_dir`, feeds `patch` on stdin and
/// waits for the process to finish.
///
/// Returns a success message containing the number of hunks in the patch.
///
/// # Errors
///
/// I/O failures while spawning, writing to, or waiting on the process are
/// propagated with `?`. A non-zero exit status yields `AgentError::Command`
/// carrying git's trimmed stderr.
async fn apply_with_git(patch: &str, working_dir: &str) -> Result<String> {
    let mut child = tokio::process::Command::new("git")
        .args(["apply", "--verbose", "-"])
        .current_dir(working_dir)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .spawn()?;

    if let Some(mut stdin) = child.stdin.take() {
        use tokio::io::AsyncWriteExt;
        stdin.write_all(patch.as_bytes()).await?;
        // Drop stdin to close the pipe so git sees end-of-input.
        drop(stdin);
    }

    let output = child.wait_with_output().await?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(AgentError::Command(stderr.trim().to_string()));
    }

    // Count hunk headers in the patch itself. The verbose stderr output is
    // made of per-file progress messages, so its line count is not the number
    // of hunks that were applied.
    let applied_hunks = patch
        .lines()
        .filter(|line| line.starts_with("@@"))
        .count();
    Ok(format!(
        "Patch applied successfully ({applied_hunks} hunks)"
    ))
}

/// Manual fallback: parse and apply unified diff hunks directly.
async fn apply_manual(patch: &str, file_path: &str) -> Result<String> {
    let content = tokio::fs::read_to_string(file_path).await?;
    let lines: Vec<&str> = content.lines().collect();

    let hunks = parse_hunks(patch)?;
    if hunks.is_empty() {
        return Err(AgentError::InvalidArgument(
            "No valid hunks found in patch".to_string(),
        ));
    }

    // Apply hunks in reverse order (so line numbers stay valid)
    let mut result_lines: Vec<String> = lines.iter().map(|l| l.to_string()).collect();

    let mut sorted_hunks = hunks;
    sorted_hunks.sort_by(|a, b| b.old_start.cmp(&a.old_start));

    for hunk in &sorted_hunks {
        let start = (hunk.old_start as usize).saturating_sub(1);
        let end = start + hunk.old_count as usize;

        if end > result_lines.len() {
            return Err(AgentError::InvalidArgument(format!(
                "Hunk at line {} extends beyond file (file has {} lines)",
                hunk.old_start,
                result_lines.len()
            )));
        }

        // Verify context lines match
        let mut old_idx = start;
        for line in &hunk.lines {
            match line {
                PatchLine::Context(text) | PatchLine::Remove(text) => {
                    if old_idx < result_lines.len() {
                        let actual = &result_lines[old_idx];
                        if actual.trim() != text.trim() {
                            return Err(AgentError::InvalidArgument(format!(
                                "Context mismatch at line {}: expected '{}', got '{}'",
                                old_idx + 1,
                                text,
                                actual,
                            )));
                        }
                    }
                    old_idx += 1;
                }
                PatchLine::Add(_) => {}
            }
        }

        // Build replacement
        let mut new_lines: Vec<String> = Vec::new();
        for line in &hunk.lines {
            match line {
                PatchLine::Context(text) | PatchLine::Add(text) => {
                    new_lines.push(text.to_string());
                }
                PatchLine::Remove(_) => {}
            }
        }

        // Splice
        result_lines.splice(start..end, new_lines);
    }

    let new_content = result_lines.join("\n");
    // Preserve trailing newline if original had one
    let final_content = if content.ends_with('\n') {
        format!("{new_content}\n")
    } else {
        new_content
    };

    tokio::fs::write(file_path, &final_content).await?;

    Ok(format!(
        "Patch applied manually ({} hunks)",
        sorted_hunks.len()
    ))
}

/// A parsed hunk from a unified diff.
///
/// Only the old-file side of the `@@` header is stored; the new-file range is
/// not kept.
struct Hunk {
    /// Start line from the old-file range of the `@@` header
    /// (`apply_manual` treats it as a 1-based line number).
    old_start: u32,
    /// Line count from the old-file range of the `@@` header
    /// (`parse_hunk_header` uses 1 when the count is omitted).
    old_count: u32,
    /// Body lines of the hunk, in patch order.
    lines: Vec<PatchLine>,
}

/// One body line of a hunk, stored without its leading marker character.
enum PatchLine {
    /// Unchanged line (patch line prefixed with a space, or an empty line).
    Context(String),
    /// Line added by the patch (prefixed with `+`).
    Add(String),
    /// Line removed by the patch (prefixed with `-`).
    Remove(String),
}

/// Parse hunks from a unified diff string.
///
/// Every line starting with `@@` opens a new hunk. Lines before the first
/// hunk header (file headers and any other preamble) are ignored. Inside a
/// hunk, lines are classified by their first character: `+` added, `-`
/// removed, space context; a completely empty line counts as an empty context
/// line. Lines with any other prefix (such as `\ No newline at end of file`)
/// are skipped.
///
/// `---` / `+++` are treated as file headers only while no hunk is open.
/// Inside a hunk such a line is an ordinary removal or addition whose content
/// happens to start with `--` or `++`, and skipping it would silently drop
/// that change.
///
/// # Errors
///
/// Propagates the error from `parse_hunk_header` when a `@@` line cannot be
/// parsed.
fn parse_hunks(patch: &str) -> Result<Vec<Hunk>> {
    let mut hunks = Vec::new();
    let mut current_hunk: Option<Hunk> = None;

    for line in patch.lines() {
        if line.starts_with("@@") {
            // Finalize the previous hunk before opening the next one.
            if let Some(hunk) = current_hunk.take() {
                hunks.push(hunk);
            }

            // Parse @@ -old_start,old_count +new_start,new_count @@
            let (old_start, old_count) = parse_hunk_header(line)?;
            current_hunk = Some(Hunk {
                old_start,
                old_count,
                lines: Vec::new(),
            });
            continue;
        }

        // Anything before the first hunk header is preamble: skip it.
        let hunk = match current_hunk.as_mut() {
            Some(hunk) => hunk,
            None => continue,
        };

        if let Some(text) = line.strip_prefix('+') {
            hunk.lines.push(PatchLine::Add(text.to_string()));
        } else if let Some(text) = line.strip_prefix('-') {
            hunk.lines.push(PatchLine::Remove(text.to_string()));
        } else if let Some(text) = line.strip_prefix(' ') {
            hunk.lines.push(PatchLine::Context(text.to_string()));
        } else if line.is_empty() {
            // Empty context line (its leading space was stripped somewhere).
            hunk.lines.push(PatchLine::Context(String::new()));
        }
    }

    // Don't forget the last hunk.
    hunks.extend(current_hunk);

    Ok(hunks)
}

/// Extracts `(old_start, old_count)` from a `@@` hunk header line.
///
/// The header is split on whitespace and must have at least three tokens; the
/// second token is taken as the old-file range (`-start,count` or `-start`).
/// A missing count is reported as 1.
///
/// # Errors
///
/// Returns `AgentError::InvalidArgument` when the line has fewer than three
/// whitespace-separated tokens, and propagates the integer parse error when
/// the start or count is not a valid `u32`.
fn parse_hunk_header(line: &str) -> Result<(u32, u32)> {
    // Format: @@ -old_start,old_count +new_start,new_count @@ optional text
    let mut tokens = line.split_whitespace();
    let old_token = match (tokens.next(), tokens.next(), tokens.next()) {
        (Some(_), Some(old), Some(_)) => old,
        _ => {
            return Err(AgentError::InvalidArgument(format!(
                "Invalid hunk header: {}",
                line
            )))
        }
    };

    // Tolerate a missing '-' marker in front of the range.
    let range = old_token.strip_prefix('-').unwrap_or(old_token);

    match range.split_once(',') {
        Some((start, count)) => Ok((start.parse::<u32>()?, count.parse::<u32>()?)),
        None => Ok((range.parse::<u32>()?, 1)),
    }
}

/// Normalize a patch to ensure it has proper file headers.
///
/// The result always starts with `--- a/{path}` / `+++ b/{path}` built from
/// the caller-supplied (approval-gate-checked) path, followed by the hunks of
/// the input patch. This keeps a crafted patch from pointing the apply step at
/// a different file.
///
/// What is removed from the input:
/// * everything before the first line starting with `@@` (file headers,
///   `diff --git` / `index` lines, or other preamble);
/// * `--- `, `+++ ` and `diff --git ` lines that appear *between* hunks.
///
/// Lines *inside* a hunk are never removed. A hunk's extent is tracked with
/// the line counts from its `@@ -a,b +c,d @@` header, so removing a line such
/// as `-- comment` (written `--- comment` in the patch) or adding `++ x` is
/// not mistaken for a file header. Headers that cannot be parsed strictly are
/// treated as announcing zero lines, which falls back to stripping every
/// header-looking line after them.
///
/// Lines are re-joined with `\n`; a trailing newline is kept if the input
/// ended with one.
fn normalize_patch(patch: &str, path: &str) -> String {
    /// Returns the line count of one side of a hunk range (`start` or
    /// `start,count`); a missing count means one line.
    fn range_count(range: &str) -> Option<u64> {
        let is_number = |s: &str| !s.is_empty() && s.bytes().all(|b| b.is_ascii_digit());
        match range.split_once(',') {
            Some((start, count)) if is_number(start) && is_number(count) => count.parse().ok(),
            None if is_number(range) => Some(1),
            _ => None,
        }
    }

    /// Strictly parses `@@ -a[,b] +c[,d] @@` into `(old_count, new_count)`.
    fn hunk_counts(header: &str) -> Option<(u64, u64)> {
        let rest = header.strip_prefix("@@ -")?;
        let (old, rest) = rest.split_once(" +")?;
        let (new, _) = rest.split_once(" @@")?;
        Some((range_count(old)?, range_count(new)?))
    }

    let mut kept: Vec<&str> = Vec::new();
    let mut seen_hunk = false;
    // Old-side / new-side lines the current hunk still expects.
    let mut old_left: u64 = 0;
    let mut new_left: u64 = 0;

    for line in patch.lines() {
        if line.starts_with("@@") {
            seen_hunk = true;
            let (old, new) = hunk_counts(line).unwrap_or((0, 0));
            old_left = old;
            new_left = new;
            kept.push(line);
        } else if !seen_hunk {
            // Preamble: replaced by the headers generated below.
        } else if old_left > 0 || new_left > 0 {
            // Inside a hunk: keep the line verbatim and account for it.
            match line.as_bytes().first() {
                None | Some(b' ') => {
                    old_left = old_left.saturating_sub(1);
                    new_left = new_left.saturating_sub(1);
                }
                Some(b'-') => old_left = old_left.saturating_sub(1),
                Some(b'+') => new_left = new_left.saturating_sub(1),
                // e.g. "\ No newline at end of file" consumes no line count.
                Some(_) => {}
            }
            kept.push(line);
        } else {
            // Between hunks: drop anything that could name another file.
            let names_a_file = line.starts_with("--- ")
                || line.starts_with("+++ ")
                || line.starts_with("diff --git ");
            if !names_a_file {
                kept.push(line);
            }
        }
    }

    let body = kept.join("\n");
    // Preserve trailing newline if the original had one.
    let body = if patch.ends_with('\n') {
        format!("{body}\n")
    } else {
        body
    };
    format!("--- a/{path}\n+++ b/{path}\n{body}")
}

/// Turns a tool-supplied path into the string path used for file access.
///
/// A path beginning with `/` is taken as-is; anything else is joined onto
/// `working_dir`. The result is passed through
/// `crate::agent::approval::normalize_path_lexical` (presumably resolving
/// `.`/`..` components without touching the filesystem — confirm against that
/// helper) and converted to a `String`, lossily for non-UTF-8 paths.
fn resolve_path(path: &str, working_dir: &str) -> String {
    let is_rooted = path.starts_with('/');
    let joined = match is_rooted {
        true => std::path::PathBuf::from(path),
        false => std::path::Path::new(working_dir).join(path),
    };

    let normalized = crate::agent::approval::normalize_path_lexical(&joined);
    normalized.to_string_lossy().into_owned()
}

// Unit tests for header parsing, hunk parsing, patch normalization and the
// manual (non-git) apply path.
#[cfg(test)]
mod tests {
    use super::*;

    // Header with explicit counts and trailing section text.
    #[test]
    fn test_parse_hunk_header() {
        let (start, count) = parse_hunk_header("@@ -10,5 +10,7 @@ fn foo()").unwrap();
        assert_eq!(start, 10);
        assert_eq!(count, 5);
    }

    // An omitted old count is reported as 1.
    #[test]
    fn test_parse_hunk_header_single_line() {
        let (start, count) = parse_hunk_header("@@ -5 +5,2 @@").unwrap();
        assert_eq!(start, 5);
        assert_eq!(count, 1);
    }

    // Two hunks in one patch are split at the `@@` lines and the file headers
    // are ignored.
    #[test]
    fn test_parse_hunks() {
        let patch = "\
--- a/test.rs
+++ b/test.rs
@@ -1,3 +1,4 @@
 line one
+inserted line
 line two
 line three
@@ -10,2 +11,2 @@
-old line ten
+new line ten
 line eleven";

        let hunks = parse_hunks(patch).unwrap();
        assert_eq!(hunks.len(), 2);
        assert_eq!(hunks[0].old_start, 1);
        assert_eq!(hunks[0].old_count, 3);
        assert_eq!(hunks[1].old_start, 10);
        assert_eq!(hunks[1].old_count, 2);
    }

    // A patch without file headers gets them generated from the given path.
    #[test]
    fn test_normalize_patch_adds_headers() {
        let patch = "@@ -1,3 +1,4 @@\n line 1\n+added\n line 2\n line 3";
        let normalized = normalize_patch(patch, "src/main.rs");
        assert!(normalized.starts_with("--- a/src/main.rs"));
        assert!(normalized.contains("+++ b/src/main.rs"));
    }

    // Headers that already name the given path come out unchanged (they are
    // rewritten to the same text).
    #[test]
    fn test_normalize_patch_keeps_existing_headers() {
        let patch = "--- a/foo.rs\n+++ b/foo.rs\n@@ -1,1 +1,1 @@\n-old\n+new";
        let normalized = normalize_patch(patch, "foo.rs");
        assert_eq!(normalized, patch);
    }

    // End-to-end check of the manual applier on a temp file: one line is
    // replaced and one line is inserted.
    #[tokio::test]
    async fn test_manual_apply_single_hunk() {
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.txt");
        std::fs::write(&file_path, "line one\nline two\nline three\nline four\n").unwrap();

        let patch = "\
--- a/test.txt
+++ b/test.txt
@@ -2,2 +2,3 @@
 line two
-line three
+line THREE
+inserted line";

        let result = apply_manual(patch, file_path.to_str().unwrap()).await;
        assert!(result.is_ok(), "apply_manual failed: {:?}", result);

        let content = std::fs::read_to_string(&file_path).unwrap();
        assert!(content.contains("line THREE"));
        assert!(content.contains("inserted line"));
        assert!(!content.contains("line three"));
    }
}