use std::collections::HashSet;
use std::path::PathBuf;
use crate::context::AppContext;
use crate::edit;
use crate::lsp::manager::PostEditWaitOutcome;
use crate::protocol::{RawRequest, Response};
struct ParsedOp {
file: PathBuf,
kind: OpKind,
}
enum OpKind {
Write {
content: String,
},
EditMatch {
match_str: String,
replacement: String,
},
}
struct FileResult {
file: String,
syntax_valid: Option<bool>,
formatted: bool,
format_skipped_reason: Option<String>,
}
struct TransactionWriteResult {
file_result: FileResult,
file: PathBuf,
}
struct TransactionOpError {
code: &'static str,
message: String,
}
pub fn handle_transaction(req: &RawRequest, ctx: &AppContext) -> Response {
let operations = match req.params.get("operations").and_then(|v| v.as_array()) {
Some(ops) => ops,
None => {
return Response::error(
&req.id,
"invalid_request",
"transaction: missing required param 'operations' (expected array)",
);
}
};
if operations.is_empty() {
return Response::error(
&req.id,
"invalid_request",
"transaction: 'operations' array must not be empty",
);
}
let mut parsed: Vec<ParsedOp> = Vec::with_capacity(operations.len());
for (i, op) in operations.iter().enumerate() {
match parse_operation(op, i) {
Ok(p) => match validate_operation_path(p, &req.id, ctx) {
Ok(validated) => parsed.push(validated),
Err(resp) => return resp,
},
Err(msg) => return Response::error(&req.id, "invalid_request", msg),
}
}
let dry_run = edit::is_dry_run(&req.params);
if dry_run {
return handle_dry_run(req, &parsed);
}
let mut snapshotted_files: Vec<PathBuf> = Vec::new();
let mut new_files: Vec<PathBuf> = Vec::new();
let mut seen: HashSet<PathBuf> = HashSet::new();
for op in &parsed {
if seen.contains(&op.file) {
continue;
}
seen.insert(op.file.clone());
if op.file.exists() {
let snapshot_result = {
let mut store = ctx.backup().borrow_mut();
store.snapshot(req.session(), &op.file, "transaction")
};
if let Err(e) = snapshot_result {
return Response::error(&req.id, e.code(), e.to_string());
}
snapshotted_files.push(op.file.clone());
} else {
new_files.push(op.file.clone());
}
}
let multi_file_write_paths: Vec<String> = parsed
.iter()
.map(|op| op.file.display().to_string())
.collect();
let lsp_params = params_with_multi_file_write_paths(&req.params, &multi_file_write_paths);
let mut results: Vec<TransactionWriteResult> = Vec::new();
for (i, op) in parsed.iter().enumerate() {
let new_content = match compute_new_content(op) {
Ok(c) => c,
Err(err) => {
let failures = rollback(ctx, req.session(), &snapshotted_files, &new_files);
return transaction_error(
&req.id,
err.code,
i,
&snapshotted_files,
&new_files,
&err.message,
&failures,
);
}
};
match edit::write_format_validate(&op.file, &new_content, &ctx.config(), &req.params) {
Ok(wr) => {
if !snapshotted_files.contains(&op.file) && !new_files.contains(&op.file) {
new_files.push(op.file.clone());
}
results.push(TransactionWriteResult {
file: op.file.clone(),
file_result: FileResult {
file: op.file.display().to_string(),
syntax_valid: wr.syntax_valid,
formatted: wr.formatted,
format_skipped_reason: wr.format_skipped_reason,
},
});
}
Err(e) => {
let failures = rollback(ctx, req.session(), &snapshotted_files, &new_files);
return transaction_error(
&req.id,
"transaction_failed",
i,
&snapshotted_files,
&new_files,
&format!("write failed: {}", e),
&failures,
);
}
}
}
for (i, result) in results.iter().enumerate() {
if result.file_result.syntax_valid == Some(false) {
let failures = rollback(ctx, req.session(), &snapshotted_files, &new_files);
return transaction_error(
&req.id,
"transaction_failed",
i,
&snapshotted_files,
&new_files,
&format!("syntax error in {}", result.file_result.file),
&failures,
);
}
}
let mut stripped_params = lsp_params.clone();
if let Some(obj) = stripped_params.as_object_mut() {
obj.remove("multi_file_write_paths");
}
let mut lsp_outcomes: Vec<PostEditWaitOutcome> = Vec::new();
for (idx, result) in results.iter().enumerate() {
if let Ok(final_content) = std::fs::read_to_string(&result.file) {
let params_for_call = if idx == 0 {
&lsp_params
} else {
&stripped_params
};
if let Some(lsp_outcome) =
ctx.lsp_post_write(&result.file, &final_content, params_for_call)
{
lsp_outcomes.push(lsp_outcome);
}
}
}
let files_modified = results.len();
let result_json: Vec<serde_json::Value> = results
.iter()
.map(|r| {
let file_result = &r.file_result;
let mut v = serde_json::json!({
"file": file_result.file,
"syntax_valid": file_result.syntax_valid,
"formatted": file_result.formatted,
});
if let Some(ref reason) = file_result.format_skipped_reason {
v["format_skipped_reason"] = serde_json::json!(reason);
}
v
})
.collect();
let lsp_outcome = merge_lsp_outcomes(lsp_outcomes.iter());
log::debug!(
"[aft] transaction: {} files modified successfully",
files_modified
);
let mut result = serde_json::json!({
"ok": true,
"files_modified": files_modified,
"results": result_json,
});
append_lsp_diagnostics_to_transaction(&mut result, lsp_outcome.as_ref());
Response::success(&req.id, result)
}
fn params_with_multi_file_write_paths(
params: &serde_json::Value,
paths: &[String],
) -> serde_json::Value {
let mut params = params.clone();
params["multi_file_write_paths"] = serde_json::json!(paths);
params
}
fn merge_lsp_outcomes<'a>(
outcomes: impl Iterator<Item = &'a PostEditWaitOutcome>,
) -> Option<PostEditWaitOutcome> {
let mut merged: Option<PostEditWaitOutcome> = None;
for outcome in outcomes {
let merged = merged.get_or_insert_with(PostEditWaitOutcome::default);
merged.diagnostics.extend(outcome.diagnostics.clone());
merged
.pending_servers
.extend(outcome.pending_servers.clone());
merged.exited_servers.extend(outcome.exited_servers.clone());
}
merged
}
fn append_lsp_diagnostics_to_transaction(
result: &mut serde_json::Value,
outcome: Option<&PostEditWaitOutcome>,
) {
let Some(outcome) = outcome else {
return;
};
result["lsp_diagnostics"] = serde_json::json!(outcome
.diagnostics
.iter()
.map(|d| {
serde_json::json!({
"file": d.file.display().to_string(),
"line": d.line,
"column": d.column,
"end_line": d.end_line,
"end_column": d.end_column,
"severity": d.severity.as_str(),
"message": d.message,
"code": d.code,
"source": d.source,
})
})
.collect::<Vec<_>>());
result["lsp_complete"] = serde_json::Value::Bool(outcome.complete());
if !outcome.pending_servers.is_empty() {
result["lsp_pending_servers"] = serde_json::json!(outcome
.pending_servers
.iter()
.map(|key| key.kind.id_str().to_string())
.collect::<Vec<_>>());
}
if !outcome.exited_servers.is_empty() {
result["lsp_exited_servers"] = serde_json::json!(outcome
.exited_servers
.iter()
.map(|key| key.kind.id_str().to_string())
.collect::<Vec<_>>());
}
}
fn validate_operation_path(
mut op: ParsedOp,
req_id: &str,
ctx: &AppContext,
) -> Result<ParsedOp, Response> {
op.file = ctx.validate_path(req_id, &op.file)?;
Ok(op)
}
fn parse_operation(op: &serde_json::Value, index: usize) -> Result<ParsedOp, String> {
let file = op
.get("file")
.and_then(|v| v.as_str())
.ok_or_else(|| format!("transaction: operation[{}] missing 'file'", index))?;
let command = op
.get("command")
.and_then(|v| v.as_str())
.ok_or_else(|| format!("transaction: operation[{}] missing 'command'", index))?;
let kind = match command {
"write" => {
let content = op.get("content").and_then(|v| v.as_str()).ok_or_else(|| {
format!(
"transaction: operation[{}] 'write' requires 'content'",
index
)
})?;
OpKind::Write {
content: content.to_string(),
}
}
"edit_match" => {
let match_str = op.get("match").and_then(|v| v.as_str()).ok_or_else(|| {
format!(
"transaction: operation[{}] 'edit_match' requires 'match'",
index
)
})?;
let replacement = op
.get("replacement")
.and_then(|v| v.as_str())
.ok_or_else(|| {
"transaction: edit_match operation requires 'replacement' field".to_string()
})?;
OpKind::EditMatch {
match_str: match_str.to_string(),
replacement: replacement.to_string(),
}
}
other => {
return Err(format!(
"transaction: operation[{}] unknown command '{}' (expected 'write' or 'edit_match')",
index, other
));
}
};
Ok(ParsedOp {
file: PathBuf::from(file),
kind,
})
}
fn compute_new_content(op: &ParsedOp) -> Result<String, TransactionOpError> {
match &op.kind {
OpKind::Write { content } => Ok(content.clone()),
OpKind::EditMatch {
match_str,
replacement,
} => {
let source = if op.file.exists() {
std::fs::read_to_string(&op.file).map_err(|e| TransactionOpError {
code: "transaction_failed",
message: format!("failed to read {}: {}", op.file.display(), e),
})?
} else {
String::new()
};
let (byte_start, byte_len) = find_single_fuzzy_match(&source, match_str, op)?;
edit::replace_byte_range(&source, byte_start, byte_start + byte_len, replacement)
.map_err(|e| TransactionOpError {
code: e.code(),
message: e.to_string(),
})
}
}
}
fn handle_dry_run(req: &RawRequest, ops: &[ParsedOp]) -> Response {
let mut diffs: Vec<serde_json::Value> = Vec::with_capacity(ops.len());
for op in ops {
let original = if op.file.exists() {
std::fs::read_to_string(&op.file).unwrap_or_default()
} else {
String::new()
};
let new_content = match compute_new_content_dry(op, &original) {
Ok(c) => c,
Err(err) => {
return Response::error(&req.id, err.code, err.message);
}
};
let dr = edit::dry_run_diff(&original, &new_content, &op.file);
diffs.push(serde_json::json!({
"file": op.file.display().to_string(),
"diff": dr.diff,
"syntax_valid": dr.syntax_valid,
}));
}
Response::success(
&req.id,
serde_json::json!({
"ok": true,
"dry_run": true,
"diffs": diffs,
}),
)
}
fn compute_new_content_dry(op: &ParsedOp, original: &str) -> Result<String, TransactionOpError> {
match &op.kind {
OpKind::Write { content } => Ok(content.clone()),
OpKind::EditMatch {
match_str,
replacement,
} => {
let (byte_start, byte_len) = find_single_fuzzy_match(original, match_str, op)?;
edit::replace_byte_range(original, byte_start, byte_start + byte_len, replacement)
.map_err(|e| TransactionOpError {
code: e.code(),
message: e.to_string(),
})
}
}
}
fn find_single_fuzzy_match(
source: &str,
match_str: &str,
op: &ParsedOp,
) -> Result<(usize, usize), TransactionOpError> {
let fuzzy_matches = crate::fuzzy_match::find_all_fuzzy(source, match_str);
if fuzzy_matches.is_empty() {
return Err(TransactionOpError {
code: "transaction_failed",
message: format!("match '{}' not found in {}", match_str, op.file.display()),
});
}
if fuzzy_matches.len() > 1 {
return Err(TransactionOpError {
code: "ambiguous_match",
message: format!(
"match '{}' is ambiguous ({} occurrences) in {}",
match_str,
fuzzy_matches.len(),
op.file.display()
),
});
}
let fuzzy_match = &fuzzy_matches[0];
Ok((fuzzy_match.byte_start, fuzzy_match.byte_len))
}
struct RollbackFailure {
file: String,
action: String,
error: String,
}
fn rollback(
ctx: &AppContext,
session: &str,
snapshotted: &[PathBuf],
new_files: &[PathBuf],
) -> Vec<RollbackFailure> {
let mut failures = Vec::new();
for path in snapshotted.iter().rev() {
let result = {
let mut store = ctx.backup().borrow_mut();
store.restore_latest(session, path)
};
if let Err(e) = result {
log::warn!(
"[aft] transaction rollback: failed to restore {}: {}",
path.display(),
e
);
failures.push(RollbackFailure {
file: path.display().to_string(),
action: "restore".to_string(),
error: e.to_string(),
});
}
}
for path in new_files {
if path.exists() {
if let Err(e) = std::fs::remove_file(path) {
log::warn!(
"[aft] transaction rollback: failed to delete new file {}: {}",
path.display(),
e
);
failures.push(RollbackFailure {
file: path.display().to_string(),
action: "delete".to_string(),
error: e.to_string(),
});
}
}
}
failures
}
fn transaction_error(
req_id: &str,
code: &str,
failed_index: usize,
snapshotted: &[PathBuf],
new_files: &[PathBuf],
message: &str,
rollback_failures: &[RollbackFailure],
) -> Response {
let mut rolled_back: Vec<serde_json::Value> = snapshotted
.iter()
.map(|p| serde_json::json!({ "file": p.display().to_string(), "action": "restored" }))
.collect();
for p in new_files {
rolled_back.push(serde_json::json!({
"file": p.display().to_string(),
"action": "deleted",
}));
}
log::debug!(
"[aft] transaction failed at operation[{}]: {} — rolled back {} files",
failed_index,
message,
rolled_back.len()
);
let mut data = serde_json::json!({
"failed_operation": failed_index,
"rolled_back": rolled_back,
});
if !rollback_failures.is_empty() {
data["rollback_failures"] = serde_json::json!(rollback_failures
.iter()
.map(|f| serde_json::json!({
"file": f.file,
"action": f.action,
"error": f.error,
}))
.collect::<Vec<_>>());
}
Response::error_with_data(req_id, code, message, data)
}