kota 0.1.3

A lightweight, highly extensible ai code agent, built in Rust.
Documentation
use super::FileToolError;
use colored::*;
use patch_apply::{apply, Patch};
use rig::{completion::ToolDefinition, tool::Tool};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;

#[derive(Deserialize)]
pub struct EditFileArgs {
    pub file_path: String,
    pub patch: String,
}

#[derive(Serialize, Debug)]
pub struct EditFileOutput {
    pub file_path: String,
    pub lines_added: usize,
    pub lines_removed: usize,
    pub success: bool,
    pub message: String,
}

#[derive(Deserialize, Serialize, Default)]
pub struct EditFileTool;

impl Tool for EditFileTool {
    const NAME: &'static str = "edit_file";

    type Error = FileToolError;
    type Args = EditFileArgs;
    type Output = EditFileOutput;

    async fn definition(&self, _prompt: String) -> ToolDefinition {
        ToolDefinition {
            name: "edit_file".to_string(),
            description: "Apply a unified diff patch to a file. This is efficient for making small, targeted changes to existing files without rewriting the entire content.".to_string(),
            parameters: serde_json::json!({
                "type": "object",
                "properties": {
                    "file_path": {
                        "type": "string",
                        "description": "The path to the file to edit (relative or absolute). The file must exist. Examples: 'src/main.rs', 'README.md'"
                    },
                    "patch": {
                        "type": "string",
                        "description": "A unified diff patch string in standard format. CRITICAL REQUIREMENTS:\n1. MUST start with '--- a/filename' header line\n2. MUST have '+++ b/filename' as second line\n3. MUST have hunk header: '@@ -old_start,old_count +new_start,new_count @@'\n4. Context lines MUST start with ' ' (space)\n5. Removed lines MUST start with '-'\n6. Added lines MUST start with '+'\n7. EVERY line (including the last) MUST end with '\\n'\n\nExample:\n--- a/main.rs\n+++ b/main.rs\n@@ -1,3 +1,4 @@\n fn main() {\n+    println!(\"Hello, world!\");\n     // existing code\n }\n"
                    }
                },
                "required": ["file_path", "patch"]
            })
        }
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        let file_path = &args.file_path;
        let patch_str = &args.patch;
        let path = Path::new(file_path);

        // Check if file exists
        if !path.exists() {
            return Err(FileToolError::FileNotFound(file_path.clone()));
        }

        // Check if it's actually a file (not a directory)
        if !path.is_file() {
            return Err(FileToolError::NotAFile(file_path.clone()));
        }

        // Read the current file content
        let current_content = fs::read_to_string(file_path)?;

        // Ensure patch_str ends with a newline
        let patch_str_normalized = if !patch_str.ends_with('\n') {
            format!("{}\n", patch_str)
        } else {
            patch_str.to_string()
        };

        // Parse the patch
        let patch = Patch::from_single(&patch_str_normalized).map_err(|e| {
            FileToolError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Failed to parse patch: {}", e),
            ))
        })?;

        // Apply the patch using patch_apply::apply
        let patched_content = apply(current_content, patch);

        // Calculate statistics
        let original_lines: Vec<&str> = args.patch.lines().collect();
        let mut lines_added = 0usize;
        let mut lines_removed = 0usize;

        for line in original_lines {
            if line.starts_with('+') && !line.starts_with("+++") {
                lines_added += 1;
            } else if line.starts_with('-') && !line.starts_with("---") {
                lines_removed += 1;
            }
        }

        // Write the modified content back to the file
        match fs::write(file_path, &patched_content) {
            Ok(()) => Ok(EditFileOutput {
                file_path: file_path.clone(),
                lines_added,
                lines_removed,
                success: true,
                message: format!(
                    "Successfully applied patch to '{}': +{} lines, -{} lines",
                    file_path, lines_added, lines_removed
                ),
            }),
            Err(e) => match e.kind() {
                std::io::ErrorKind::PermissionDenied => {
                    Err(FileToolError::PermissionDenied(file_path.clone()))
                }
                _ => Err(FileToolError::Io(e)),
            },
        }
    }
}

#[derive(Deserialize, Serialize, Default)]
pub struct WrappedEditFileTool {
    inner: EditFileTool,
}

impl WrappedEditFileTool {
    pub fn new() -> Self {
        Self {
            inner: EditFileTool,
        }
    }
}

impl Tool for WrappedEditFileTool {
    const NAME: &'static str = "edit_file";

    type Error = FileToolError;
    type Args = <EditFileTool as Tool>::Args;
    type Output = <EditFileTool as Tool>::Output;

    async fn definition(&self, prompt: String) -> ToolDefinition {
        self.inner.definition(prompt).await
    }

    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        println!("\n{} Edit({})", "●".bright_green(), args.file_path);

        let result = self.inner.call(args).await;

        match &result {
            Ok(output) => {
                println!(
                    "  └─ {} (+{} lines, -{} lines)",
                    format!("Patched '{}'", output.file_path).dimmed(),
                    output.lines_added.to_string().green(),
                    output.lines_removed.to_string().red()
                );
            }
            Err(e) => {
                println!("  └─ {}", format!("Error: {}", e).red());
            }
        }
        println!();
        result
    }
}