use super::traits::{FileTool, Tool};
use crate::utils::vtcodegitignore::should_exclude_file;
use anyhow::{Context, Result, anyhow};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::path::{Path, PathBuf};
use std::process::Stdio;
use std::time::SystemTime;
use tokio::process::Command;
/// Arguments accepted by the `srgn` tool, deserialized from the tool-call JSON.
///
/// Boolean switches default to `false` when omitted; optional fields default to `None`.
#[derive(Debug, Deserialize)]
pub struct SrgnInput {
    /// File path or glob pattern relative to the workspace root, handed to srgn via `--glob`.
    pub path: String,
    /// Regular-expression scope, passed to srgn as its positional scope argument.
    pub scope: Option<String>,
    /// Replacement text; mandatory for [`SrgnAction::Replace`].
    pub replacement: Option<String>,
    /// Language scope written as `"<language> <scope>"`, e.g. `"rust comments"`.
    /// Its first word also selects the language for `custom_query` / `custom_query_file`
    /// (defaulting to Rust when absent).
    pub language_scope: Option<String>,
    /// Transformation to apply to the scoped text.
    pub action: SrgnAction,
    /// Treat the regex scope as a literal string (`--literal-string`).
    #[serde(default)]
    pub literal_string: bool,
    /// Preview without writing files (`--dry-run`).
    #[serde(default)]
    pub dry_run: bool,
    /// Invert the scope selection (`--invert`).
    #[serde(default)]
    pub invert: bool,
    /// Inline custom query, passed through the language-specific `--<lang>-query` flag.
    pub custom_query: Option<String>,
    /// Path to a custom query file, passed through `--<lang>-query-file`.
    pub custom_query_file: Option<String>,
    /// Extra raw command-line arguments forwarded to srgn unchanged.
    pub flags: Option<Vec<String>>,
    /// Pass `--fail-any`.
    #[serde(default)]
    pub fail_any: bool,
    /// Pass `--fail-none`.
    #[serde(default)]
    pub fail_none: bool,
    /// Pass `--join-language-scopes`.
    #[serde(default)]
    pub join_language_scopes: bool,
    /// Pass `--hidden`.
    #[serde(default)]
    pub hidden: bool,
    /// Pass `--gitignored`.
    #[serde(default)]
    pub gitignored: bool,
    /// Pass `--sorted`.
    #[serde(default)]
    pub sorted: bool,
    /// Thread count for `--threads`; `0` is treated as "not set".
    pub threads: Option<usize>,
    /// Pass `--fail-no-files`.
    #[serde(default)]
    pub fail_no_files: bool,
    /// Extra switches used by the German action.
    pub german_options: Option<GermanOptions>,
}
/// Optional switches for [`SrgnAction::German`]; both default to `false` when omitted.
#[derive(Debug, Deserialize)]
pub struct GermanOptions {
    /// Emit `--german-prefer-original`.
    #[serde(default)]
    pub prefer_original: bool,
    /// Emit `--german-naive`.
    #[serde(default)]
    pub naive: bool,
}
/// Operation the tool asks srgn to perform on the scoped text.
///
/// Variants are (de)serialized in `snake_case`: `"replace"`, `"delete"`, `"upper"`,
/// `"lower"`, `"titlecase"`, `"normalize"`, `"german"`, `"symbols"`, `"squeeze"`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SrgnAction {
    /// Substitute matches with [`SrgnInput::replacement`] (given after `--`).
    Replace,
    /// Remove matches (`--delete`).
    Delete,
    /// Upper-case matches (`--upper`).
    Upper,
    /// Lower-case matches (`--lower`).
    Lower,
    /// Title-case matches (`--titlecase`).
    Titlecase,
    /// Normalize matches (`--normalize`).
    Normalize,
    /// Apply the German action (`--german`); see [`GermanOptions`].
    German,
    /// Apply the symbols action (`--symbols`).
    Symbols,
    /// Apply the squeeze action (`--squeeze`).
    Squeeze,
}
/// Tool that runs the external `srgn` binary against files inside one workspace.
#[derive(Clone)]
pub struct SrgnTool {
    /// Directory srgn runs in; target paths are resolved and confined relative to it.
    workspace_root: PathBuf,
}
impl SrgnTool {
    /// Creates a tool whose srgn invocations run with `workspace_root` as working directory.
    pub fn new(workspace_root: PathBuf) -> Self {
        Self { workspace_root }
    }

    /// Returns `true` when `path` contains glob metacharacters (`*`, `?`, `[`), i.e. it
    /// cannot be resolved to a single file system entry before srgn expands it.
    fn is_glob_pattern(path: &str) -> bool {
        path.contains(|c: char| matches!(c, '*' | '?' | '['))
    }

    /// Returns the language named by the first word of `language_scope`, or `"rust"` when
    /// no language is given. Used to pick the custom-query flag.
    fn query_language(language_scope: Option<&str>) -> &str {
        language_scope
            .and_then(|scope| scope.split_whitespace().next())
            .unwrap_or("rust")
    }

    /// Maps a language name to srgn's custom-query flag: `--<lang>-query`, or
    /// `--<lang>-query-file` when `from_file` is set.
    ///
    /// # Errors
    /// Returns an error for languages without a query flag.
    fn custom_query_flag(lang: &str, from_file: bool) -> Result<String> {
        let flag = match lang {
            "rust" | "rs" => "--rust-query",
            "python" | "py" => "--python-query",
            // JavaScript is routed to the TypeScript flag.
            "javascript" | "js" | "typescript" | "ts" => "--typescript-query",
            "go" => "--go-query",
            "c" => "--c-query",
            "csharp" | "cs" | "c#" => "--csharp-query",
            "hcl" => "--hcl-query",
            _ if from_file => {
                return Err(anyhow!(
                    "Unsupported language for custom query file: {}",
                    lang
                ));
            }
            _ => return Err(anyhow!("Unsupported language for custom query: {}", lang)),
        };
        Ok(if from_file {
            format!("{flag}-file")
        } else {
            flag.to_string()
        })
    }

    /// Maps a language name to srgn's language-scope flag (e.g. `rust` -> `--rust`).
    ///
    /// # Errors
    /// Returns an error for unsupported languages.
    fn language_flag(lang: &str) -> Result<&'static str> {
        Ok(match lang {
            "rust" | "rs" => "--rust",
            "python" | "py" => "--python",
            // JavaScript is routed to the TypeScript flag.
            "javascript" | "js" | "typescript" | "ts" => "--typescript",
            "go" => "--go",
            "c" => "--c",
            "csharp" | "cs" | "c#" => "--csharp",
            "hcl" => "--hcl",
            _ => return Err(anyhow!("Unsupported language: {}", lang)),
        })
    }

    /// Appends the scope arguments: at most one syntax scope (custom query file, custom
    /// query, or language scope, in that priority) followed by an optional regex scope.
    ///
    /// The regex scope comes from `input.scope` or from any words after
    /// `"<language> <scope>"` in `input.language_scope`; supplying both is an error, as
    /// silently dropping one would widen the edit to the whole syntax scope. When no
    /// scope at all is given, `.*` is used.
    fn push_scope_args(input: &SrgnInput, args: &mut Vec<String>) -> Result<()> {
        let mut pattern_words: Vec<String> = Vec::new();
        let has_syntax_scope = if let Some(query_file) = &input.custom_query_file {
            let lang = Self::query_language(input.language_scope.as_deref());
            args.push(Self::custom_query_flag(lang, true)?);
            args.push(query_file.clone());
            true
        } else if let Some(query) = &input.custom_query {
            let lang = Self::query_language(input.language_scope.as_deref());
            args.push(Self::custom_query_flag(lang, false)?);
            args.push(query.clone());
            true
        } else if let Some(lang_scope) = &input.language_scope {
            let parts: Vec<&str> = lang_scope.split_whitespace().collect();
            let [lang, scope, rest @ ..] = parts.as_slice() else {
                return Err(anyhow!(
                    "Invalid language scope format. Expected 'language scope' or 'language scope~pattern', got: {}",
                    lang_scope
                ));
            };
            args.push(Self::language_flag(lang)?.to_string());
            args.push(scope.to_string());
            pattern_words.extend(rest.iter().map(ToString::to_string));
            true
        } else {
            false
        };

        if let Some(scope) = &input.scope {
            if !pattern_words.is_empty() {
                return Err(anyhow!(
                    "Provide the regex scope either via 'scope' or after the language scope, not both"
                ));
            }
            pattern_words.push(scope.clone());
        }

        if pattern_words.is_empty() {
            if !has_syntax_scope {
                args.push(".*".to_string());
            }
            return Ok(());
        }
        if input.literal_string {
            args.push("--literal-string".to_string());
        }
        args.extend(pattern_words);
        Ok(())
    }

    /// Translates `input` into the argument vector for the `srgn` binary.
    ///
    /// Order: global switches, `--glob <path>`, scope arguments, the action flag, the
    /// caller's extra `flags`, and finally `-- <replacement>` for [`SrgnAction::Replace`].
    /// The replacement is kept last because everything after `--` is taken literally, so
    /// extra flags placed after it would not be parsed as options.
    ///
    /// # Errors
    /// Fails on unsupported languages, malformed language scopes, a regex scope given
    /// twice, or a replace action without a replacement string.
    fn build_command_args(&self, input: &SrgnInput) -> Result<Vec<String>> {
        let mut args: Vec<String> = [
            (input.dry_run, "--dry-run"),
            (input.invert, "--invert"),
            (input.fail_any, "--fail-any"),
            (input.fail_none, "--fail-none"),
            (input.join_language_scopes, "--join-language-scopes"),
            (input.hidden, "--hidden"),
            (input.gitignored, "--gitignored"),
            (input.sorted, "--sorted"),
            (input.fail_no_files, "--fail-no-files"),
        ]
        .into_iter()
        .filter_map(|(enabled, flag)| enabled.then(|| flag.to_string()))
        .collect();

        // A thread count of zero is treated as "not set".
        if let Some(threads) = input.threads.filter(|&count| count > 0) {
            args.push("--threads".to_string());
            args.push(threads.to_string());
        }
        if let Some(german_opts) = &input.german_options {
            if german_opts.prefer_original {
                args.push("--german-prefer-original".to_string());
            }
            if german_opts.naive {
                args.push("--german-naive".to_string());
            }
        }

        args.push("--glob".to_string());
        args.push(input.path.clone());

        Self::push_scope_args(input, &mut args)?;

        let action_flag = match &input.action {
            SrgnAction::Replace => None,
            SrgnAction::Delete => Some("--delete"),
            SrgnAction::Upper => Some("--upper"),
            SrgnAction::Lower => Some("--lower"),
            SrgnAction::Titlecase => Some("--titlecase"),
            SrgnAction::Normalize => Some("--normalize"),
            SrgnAction::German => Some("--german"),
            SrgnAction::Symbols => Some("--symbols"),
            SrgnAction::Squeeze => Some("--squeeze"),
        };
        if let Some(flag) = action_flag {
            args.push(flag.to_string());
        }

        if let Some(flags) = &input.flags {
            args.extend(flags.iter().cloned());
        }

        if matches!(input.action, SrgnAction::Replace) {
            let replacement = input
                .replacement
                .as_ref()
                .ok_or_else(|| anyhow!("Replacement string required for replace action"))?;
            args.push("--".to_string());
            args.push(replacement.clone());
        }
        Ok(args)
    }

    /// Resolves `path` against the workspace root and ensures it stays inside it.
    ///
    /// The root is canonicalized as well, so symlinked or relative roots compare
    /// correctly. If the root cannot be canonicalized, it is used as given.
    ///
    /// # Errors
    /// Fails if `path` cannot be canonicalized (e.g. it does not exist) or resolves
    /// outside the workspace.
    fn validate_path(&self, path: &str) -> Result<PathBuf> {
        let full_path = self.workspace_root.join(path);
        let canonical =
            std::fs::canonicalize(&full_path).with_context(|| format!("Invalid path: {}", path))?;
        let root = std::fs::canonicalize(&self.workspace_root)
            .unwrap_or_else(|_| self.workspace_root.clone());
        if !canonical.starts_with(&root) {
            return Err(anyhow!("Path '{}' is outside workspace", path));
        }
        Ok(canonical)
    }

    /// Rejects glob patterns that could reach outside the workspace: patterns with a `..`
    /// component, and absolute patterns that do not start with the workspace root.
    fn validate_glob(&self, pattern: &str) -> Result<()> {
        let pattern_path = Path::new(pattern);
        let has_parent_dir = pattern_path
            .components()
            .any(|component| matches!(component, std::path::Component::ParentDir));
        let outside_absolute =
            pattern_path.is_absolute() && !pattern_path.starts_with(&self.workspace_root);
        if has_parent_dir || outside_absolute {
            return Err(anyhow!("Path '{}' is outside workspace", pattern));
        }
        Ok(())
    }

    /// Returns `true` if `path`'s modification time is later than `before_time`.
    fn was_file_modified(&self, path: &Path, before_time: SystemTime) -> Result<bool> {
        let metadata = std::fs::metadata(path)?;
        let modified_time = metadata.modified()?;
        Ok(modified_time > before_time)
    }

    /// Runs `srgn` with `args` in the workspace root and returns its output.
    ///
    /// Only the value after `--glob` is treated as a file-system target; scopes, queries
    /// and replacements are never interpreted as paths. A concrete target must resolve
    /// inside the workspace; a glob target must not escape it. When the target is a single
    /// regular file and `--dry-run` is absent, its modification time must advance.
    ///
    /// Returns stdout, stderr, or both (trimmed, newline-separated) when both are non-empty.
    ///
    /// # Errors
    /// Fails if the target escapes the workspace, srgn cannot be spawned, srgn exits
    /// unsuccessfully, or the targeted file was left unmodified.
    async fn execute_srgn(&self, args: &[String]) -> Result<String> {
        let target = args
            .iter()
            .position(|arg| arg == "--glob")
            .and_then(|index| args.get(index + 1));
        let tracked_file = match target {
            Some(pattern) if Self::is_glob_pattern(pattern) => {
                self.validate_glob(pattern)?;
                None
            }
            Some(path) => {
                let canonical = self.validate_path(path)?;
                // Directory mtimes do not reflect in-place edits of their files, so only
                // regular files are tracked.
                canonical.is_file().then_some(canonical)
            }
            None => None,
        };
        // Record the pre-run modification time so the edit can be confirmed afterwards.
        let tracked = tracked_file.map(|path| {
            let before = std::fs::metadata(&path)
                .and_then(|m| m.modified())
                .unwrap_or(SystemTime::UNIX_EPOCH);
            (path, before)
        });

        let output = Command::new("srgn")
            .args(args)
            .current_dir(&self.workspace_root)
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .output()
            .await
            .with_context(|| format!("Failed to execute srgn command with args: {:?}", args))?;
        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
        if !output.status.success() {
            return Err(anyhow!(
                "srgn command failed with exit code {}: {}",
                output.status.code().unwrap_or(-1),
                stderr.trim()
            ));
        }

        let dry_run = args.iter().any(|arg| arg == "--dry-run");
        if !dry_run
            && let Some((path, before)) = &tracked
            && !self.was_file_modified(path, *before)?
        {
            return Err(anyhow!(
                "File '{}' was not modified as expected",
                path.display()
            ));
        }

        if stdout.is_empty() {
            Ok(stderr)
        } else if stderr.is_empty() {
            Ok(stdout)
        } else {
            Ok(format!("{}\n{}", stdout.trim(), stderr.trim()))
        }
    }

    /// Performs cheap pre-flight checks on `input` before any command is built.
    ///
    /// # Errors
    /// - the path is not a glob pattern and does not exist under the workspace root;
    /// - a replace action has no replacement string;
    /// - a delete action has no scope of any kind (regex scope, language scope, or custom
    ///   query). Without one, the default `.*` scope would be used.
    fn validate_input(&self, input: &SrgnInput) -> Result<()> {
        let path = self.workspace_root.join(&input.path);
        if !Self::is_glob_pattern(&input.path) && !path.exists() {
            return Err(anyhow!("Path '{}' does not exist", input.path));
        }
        let has_scope = input.scope.is_some()
            || input.language_scope.is_some()
            || input.custom_query.is_some()
            || input.custom_query_file.is_some();
        match &input.action {
            SrgnAction::Replace if input.replacement.is_none() => {
                return Err(anyhow!("Replacement action requires a replacement string"));
            }
            SrgnAction::Delete if !has_scope => {
                return Err(anyhow!(
                    "Delete action requires either a scope pattern or language scope"
                ));
            }
            _ => {}
        }
        Ok(())
    }
}
#[async_trait]
impl Tool for SrgnTool {
    /// Parses the JSON arguments into [`SrgnInput`], validates them, runs srgn and
    /// reports the outcome.
    ///
    /// `modified_files` lists the target path only when it names a concrete regular file
    /// and the run was not a dry run. Glob targets are not expanded here, so they are not
    /// listed. Previously, any argument containing a `.` was reported, including regex
    /// scopes and replacement strings.
    async fn execute(&self, args: Value) -> Result<Value> {
        let input: SrgnInput = serde_json::from_value(args)
            .with_context(|| "Failed to parse SrgnInput from arguments")?;
        self.validate_input(&input)?;
        let cmd_args = self.build_command_args(&input)?;
        let output = self.execute_srgn(&cmd_args).await?;

        let is_glob = input
            .path
            .contains(|c: char| matches!(c, '*' | '?' | '['));
        let modified_files: Vec<String> = if !input.dry_run
            && !is_glob
            && self.workspace_root.join(&input.path).is_file()
        {
            vec![input.path.clone()]
        } else {
            Vec::new()
        };

        Ok(json!({
            "success": true,
            "output": output,
            "command": format!("srgn {}", cmd_args.join(" ")),
            "dry_run": input.dry_run,
            "modified_files": modified_files
        }))
    }

    /// Tool identifier used for registration and dispatch.
    fn name(&self) -> &'static str {
        "srgn"
    }

    /// Human-readable summary of the tool.
    fn description(&self) -> &'static str {
        "Code surgeon tool for precise source code manipulation using srgn. Supports syntax-aware search and replace operations across multiple programming languages."
    }
}
/// File-scoped behaviour: exposes the workspace root and defers exclusion decisions to
/// the shared `.vtcodegitignore` helper.
#[async_trait]
impl FileTool for SrgnTool {
    /// The directory this tool operates in.
    fn workspace_root(&self) -> &PathBuf {
        &self.workspace_root
    }

    /// Returns whether `path` should be skipped, as decided by `should_exclude_file`.
    async fn should_exclude(&self, path: &Path) -> bool {
        should_exclude_file(path).await
    }
}