use std::path::{Path, PathBuf};
use ast_grep_core::tree_sitter::LanguageExt;
use crate::ast_grep_lang::AstGrepLang;
use crate::context::AppContext;
use crate::edit::dry_run_diff;
use crate::protocol::{RawRequest, Response};
/// Handle the `ast_replace` request: structurally match an ast-grep
/// `pattern` across the selected files and rewrite each match with
/// `rewrite`.
///
/// Required params: `pattern`, `rewrite`, `lang`.
/// Optional params: `paths` (search roots, resolved against the project
/// root when relative), `globs` (override globs for the walker), and
/// `dry_run` (defaults to `true`; when set, unified diffs are returned
/// instead of writing files).
///
/// Returns a success response with per-file results and aggregate
/// counters, or an error response for missing params, an unsupported
/// language, or a pattern that fails to parse.
pub fn handle_ast_replace(req: &RawRequest, ctx: &AppContext) -> Response {
    // --- required string params ---
    let pattern = match req.params.get("pattern").and_then(|v| v.as_str()) {
        Some(p) => p.to_string(),
        None => {
            return Response::error(
                &req.id,
                "invalid_request",
                "ast_replace: missing required param 'pattern'",
            );
        }
    };
    let rewrite = match req.params.get("rewrite").and_then(|v| v.as_str()) {
        Some(r) => r.to_string(),
        None => {
            return Response::error(
                &req.id,
                "invalid_request",
                "ast_replace: missing required param 'rewrite'",
            );
        }
    };
    let lang_str = match req.params.get("lang").and_then(|v| v.as_str()) {
        Some(l) => l,
        None => {
            return Response::error(
                &req.id,
                "invalid_request",
                "ast_replace: missing required param 'lang'",
            );
        }
    };
    let lang = match AstGrepLang::from_str(lang_str) {
        Some(l) => l,
        None => {
            return Response::error(
                &req.id,
                "invalid_request",
                format!(
                    "ast_replace: unsupported language '{}'. Supported: typescript, tsx, javascript, python, rust, go, c, cpp, zig, csharp",
                    lang_str
                ),
            );
        }
    };
    // --- optional params ---
    let paths: Vec<String> = req
        .params
        .get("paths")
        .and_then(|v| v.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                .collect()
        })
        .unwrap_or_default();
    let globs: Vec<String> = req
        .params
        .get("globs")
        .and_then(|v| v.as_array())
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str().map(|s| s.to_string()))
                .collect()
        })
        .unwrap_or_default();
    // dry_run defaults to true so the tool is non-destructive by default.
    let dry_run = req
        .params
        .get("dry_run")
        .and_then(|v| v.as_bool())
        .unwrap_or(true);
    let project_root = ctx
        .config()
        .project_root
        .clone()
        .unwrap_or_else(|| PathBuf::from("."));
    // Validate the pattern up front so a malformed pattern fails fast
    // instead of silently matching nothing in every file.
    {
        use ast_grep_core::matcher::Pattern as AstPattern;
        if let Err(e) = AstPattern::try_new(&pattern, lang.clone()) {
            return Response::error(
                &req.id,
                "invalid_pattern",
                format!(
                    "ast_replace: invalid pattern '{}': {}. Patterns must be complete AST nodes.",
                    pattern, e
                ),
            );
        }
    }
    let files = collect_files(&project_root, &lang, &paths, &globs);
    let mut file_results: Vec<serde_json::Value> = Vec::new();
    let mut total_replacements = 0usize; // replacements applied (or diffed in dry-run)
    let mut total_files = 0usize; // files actually modified (or diffed in dry-run)
    let mut files_searched = 0usize;
    let mut files_with_matches = 0usize;
    for file_path in &files {
        files_searched += 1;
        let original = match std::fs::read_to_string(file_path.as_path()) {
            Ok(s) => s,
            // Unreadable (e.g. non-UTF-8) files are skipped silently.
            Err(_) => continue,
        };
        // Parse each file once and reuse the tree for both counting and
        // rewriting (previously every matching file was parsed twice).
        let root = lang.ast_grep(&original);
        let doc = root.root();
        let replacement_count = doc.find_all(pattern.as_str()).count();
        if replacement_count == 0 {
            continue;
        }
        files_with_matches += 1;
        let mut edits = doc.replace_all(pattern.as_str(), rewrite.as_str());
        if edits.is_empty() {
            continue;
        }
        // Apply edits back-to-front so earlier byte offsets stay valid
        // while the buffer is being spliced.
        edits.sort_by(|a, b| b.position.cmp(&a.position));
        let mut new_bytes = original.as_bytes().to_vec();
        for edit in &edits {
            let start = edit.position;
            let end = start + edit.deleted_length;
            // end >= start always holds, so checking `end` bounds both.
            if end <= new_bytes.len() {
                new_bytes.splice(start..end, edit.inserted_text.iter().copied());
            }
        }
        // Fall back to the untouched content if the splices somehow
        // produced invalid UTF-8.
        let new_content = String::from_utf8(new_bytes).unwrap_or_else(|_| original.clone());
        if dry_run {
            total_replacements += replacement_count;
            total_files += 1;
            let diff_result = dry_run_diff(&original, &new_content, file_path.as_path());
            file_results.push(serde_json::json!({
                "file": file_path.display().to_string(),
                "diff": diff_result.diff,
                "replacements": replacement_count,
            }));
        } else {
            let validated_path =
                match validate_matched_file_path(ctx, &req.id, file_path.as_path()) {
                    Ok(path) => path,
                    Err(resp) => return resp,
                };
            // Best-effort snapshot so the write can be rolled back; a
            // failed snapshot does not block the write.
            let backup_id = ctx
                .backup()
                .borrow_mut()
                .snapshot(validated_path.as_path(), "ast_replace")
                .ok();
            match std::fs::write(validated_path.as_path(), &new_content) {
                Ok(()) => {
                    // Count only files that were actually written; the
                    // previous version also counted failed writes in
                    // the summary totals.
                    total_replacements += replacement_count;
                    total_files += 1;
                    let mut entry = serde_json::json!({
                        "file": file_path.display().to_string(),
                        "replacements": replacement_count,
                    });
                    if let Some(bid) = backup_id {
                        if let Some(obj) = entry.as_object_mut() {
                            obj.insert("backup_id".to_string(), serde_json::Value::String(bid));
                        }
                    }
                    file_results.push(entry);
                }
                Err(e) => {
                    file_results.push(serde_json::json!({
                        "file": file_path.display().to_string(),
                        "ok": false,
                        "error": e.to_string(),
                    }));
                }
            }
        }
    }
    Response::success(
        &req.id,
        serde_json::json!({
            "files": file_results,
            "total_replacements": total_replacements,
            "total_files": total_files,
            "files_with_matches": files_with_matches,
            "files_searched": files_searched,
            "dry_run": dry_run,
        }),
    )
}
/// Validate a file discovered during the walk before it is written to.
///
/// Thin delegation to `AppContext::validate_path`; on success yields the
/// path to use for the write, on rejection yields a ready-to-send error
/// `Response` carrying the request id.
fn validate_matched_file_path(
    ctx: &AppContext,
    req_id: &str,
    file_path: &Path,
) -> Result<PathBuf, Response> {
    ctx.validate_path(req_id, file_path)
}
fn collect_files(
project_root: &Path,
lang: &AstGrepLang,
paths: &[String],
globs: &[String],
) -> Vec<PathBuf> {
use ignore::WalkBuilder;
let mut override_builder = ignore::overrides::OverrideBuilder::new(project_root);
for g in globs {
let _ = override_builder.add(g);
}
let overrides = override_builder.build().ok();
let roots: Vec<PathBuf> = if paths.is_empty() {
vec![project_root.to_path_buf()]
} else {
paths
.iter()
.map(|p| {
let pb = PathBuf::from(p);
if pb.is_absolute() {
pb
} else {
project_root.join(p)
}
})
.collect()
};
let mut result = Vec::new();
for root in &roots {
let mut builder = WalkBuilder::new(root);
builder
.hidden(true)
.git_ignore(true)
.git_global(true)
.git_exclude(true);
if let Some(ref ov) = overrides {
builder.overrides(ov.clone());
}
builder.filter_entry(|entry| {
let name = entry.file_name().to_string_lossy();
if entry.file_type().map_or(false, |ft| ft.is_dir()) {
return !matches!(
name.as_ref(),
"node_modules"
| "target"
| "venv"
| ".venv"
| ".git"
| "__pycache__"
| ".tox"
| "dist"
| "build"
);
}
true
});
for entry in builder.build().filter_map(|e| e.ok()) {
if entry.file_type().map_or(false, |ft| ft.is_file()) {
let path = entry.into_path();
if lang.matches_path(&path) {
result.push(path);
}
}
}
}
result
}