#![cfg_attr(test, allow(clippy::items_after_test_module))]
use std::path::Path;
use crate::config::Config;
use crate::context::AppContext;
use crate::error::AftError;
use crate::format;
use crate::parser::{detect_language, grammar_for, FileParser};
/// Converts a zero-based `(line, col)` position into a byte offset into `source`.
///
/// Lines may end in `\n`, `\r\n`, or a lone `\r`. `col` is a byte column and is
/// clamped to the length of the addressed line, excluding its terminator. A
/// `line` past the last line yields `source.len()`.
pub fn line_col_to_byte(source: &str, line: u32, col: u32) -> usize {
    let bytes = source.as_bytes();
    let is_terminator = |b: &u8| *b == b'\n' || *b == b'\r';
    let mut start = 0usize;
    // Skip over `line` complete lines, stepping past each terminator.
    for _ in 0..line {
        let Some(offset) = bytes[start..].iter().position(is_terminator) else {
            // No further terminator: the requested line does not exist.
            return source.len();
        };
        let terminator = start + offset;
        let is_crlf = bytes[terminator] == b'\r' && bytes.get(terminator + 1) == Some(&b'\n');
        start = terminator + if is_crlf { 2 } else { 1 };
    }
    let line_len = bytes[start..]
        .iter()
        .position(is_terminator)
        .unwrap_or(bytes.len() - start);
    start + (col as usize).min(line_len)
}
/// Returns a copy of `source` with the byte range `[start, end)` replaced by `replacement`.
///
/// # Errors
///
/// Returns [`AftError::InvalidRequest`] when `start > end`, when `end` is past
/// the end of `source`, or when either bound does not fall on a UTF-8 char
/// boundary.
pub fn replace_byte_range(
    source: &str,
    start: usize,
    end: usize,
    replacement: &str,
) -> Result<String, AftError> {
    // Checked in the same order as the bounds would be violated when slicing.
    let reason = if start > end {
        Some("start must be <= end".to_string())
    } else if end > source.len() {
        Some(format!("end exceeds source length {}", source.len()))
    } else if !source.is_char_boundary(start) {
        Some("start is not a char boundary".to_string())
    } else if !source.is_char_boundary(end) {
        Some("end is not a char boundary".to_string())
    } else {
        None
    };
    if let Some(reason) = reason {
        return Err(AftError::InvalidRequest {
            message: format!("invalid byte range [{}..{}): {}", start, end, reason),
        });
    }
    Ok([&source[..start], replacement, &source[end..]].concat())
}
/// Parses the file at `path` and reports whether its syntax tree is error-free.
///
/// Returns `Ok(Some(true))` for a clean parse and `Ok(Some(false))` when the
/// tree contains error nodes. When the parser reports `AftError::InvalidRequest`
/// (presumably an unsupported language or similar — confirm in `FileParser`),
/// validity is unknown and `Ok(None)` is returned.
///
/// # Errors
///
/// Any other parser error is propagated unchanged.
pub fn validate_syntax(path: &Path) -> Result<Option<bool>, AftError> {
    let mut file_parser = FileParser::new();
    let outcome = file_parser.parse(path);
    match outcome {
        Err(AftError::InvalidRequest { .. }) => Ok(None),
        Err(other) => Err(other),
        Ok((tree, _)) => {
            let clean = !tree.root_node().has_error();
            Ok(Some(clean))
        }
    }
}
/// Parses in-memory `content` using the grammar selected from `path`'s language.
///
/// Returns `Some(true)` when the tree has no error nodes and `Some(false)` when
/// it does. Returns `None` when the language cannot be detected, the grammar
/// cannot be loaded into the parser, or parsing yields no tree.
pub fn validate_syntax_str(content: &str, path: &Path) -> Option<bool> {
    let grammar = grammar_for(detect_language(path)?);
    let mut ts_parser = tree_sitter::Parser::new();
    ts_parser.set_language(&grammar).ok()?;
    let parsed = ts_parser.parse(content.as_bytes(), None)?;
    Some(!parsed.root_node().has_error())
}
/// Returns `true` when the request asks for a diff summary.
///
/// That is the case if `include_diff` is the JSON boolean `true`, or if full
/// diff content was requested (see [`wants_diff_content`]).
pub fn wants_diff(params: &serde_json::Value) -> bool {
    let explicit = matches!(
        params.get("include_diff").and_then(|v| v.as_bool()),
        Some(true)
    );
    explicit || wants_diff_content(params)
}
pub fn wants_diff_content(params: &serde_json::Value) -> bool {
params
.get("include_diff_content")
.and_then(|v| v.as_bool())
.unwrap_or(false)
}
pub fn wants_preview(params: &serde_json::Value) -> bool {
params
.get("preview")
.and_then(|v| v.as_bool())
.unwrap_or(false)
}
/// Renders a unified diff of `before` → `after` for `file`, prefixed by an
/// `Index:` line and a separator rule.
///
/// When the texts are identical, only the `---` / `+++` file header is emitted
/// after the prefix. The header is written by hand in that case because the
/// unified-diff renderer is presumably silent when there are no hunks.
pub fn build_unified_diff(file: &str, before: &str, after: &str) -> String {
    let mut rendered = format!(
        "Index: {file}\n===================================================================\n"
    );
    if before == after {
        rendered.push_str(&format!("--- {file}\n+++ {file}\n"));
    } else {
        let text_diff = similar::TextDiff::from_lines(before, after);
        let patch = text_diff.unified_diff().header(file, file).to_string();
        rendered.push_str(&patch);
    }
    rendered
}
pub fn attach_preview_diff(
result: &mut serde_json::Value,
params: &serde_json::Value,
file: &str,
before: &str,
after: &str,
) {
result["preview"] = serde_json::json!(true);
result["diff"] = compute_diff_for_response(params, before, after);
result["preview_diff"] = serde_json::json!(build_unified_diff(file, before, after));
}
/// Counts inserted and deleted lines in a line-level diff of `before` → `after`.
///
/// Returns `(additions, deletions)`; unchanged lines are not counted.
fn diff_counts(before: &str, after: &str) -> (usize, usize) {
    use similar::ChangeTag;
    let line_diff = similar::TextDiff::from_lines(before, after);
    line_diff
        .iter_all_changes()
        .fold((0usize, 0usize), |(inserted, deleted), change| {
            match change.tag() {
                ChangeTag::Insert => (inserted + 1, deleted),
                ChangeTag::Delete => (inserted, deleted + 1),
                ChangeTag::Equal => (inserted, deleted),
            }
        })
}
/// Returns a JSON object `{ "additions": n, "deletions": m }` describing the
/// line-level diff of `before` → `after`.
pub fn compute_diff_counts(before: &str, after: &str) -> serde_json::Value {
    let counts = diff_counts(before, after);
    serde_json::json!({
        "additions": counts.0,
        "deletions": counts.1,
    })
}
pub fn compute_diff_for_response(
params: &serde_json::Value,
before: &str,
after: &str,
) -> serde_json::Value {
if wants_diff_content(params) {
compute_diff_info(before, after)
} else {
compute_diff_counts(before, after)
}
}
/// Returns line counts plus the full `before` / `after` texts of a change.
///
/// If either text is larger than 512 KiB, the texts are omitted and
/// `"truncated": true` is reported alongside the counts instead.
pub fn compute_diff_info(before: &str, after: &str) -> serde_json::Value {
    // Above this many bytes on either side the texts are not inlined.
    const INLINE_CONTENT_LIMIT: usize = 512 * 1024;
    let (additions, deletions) = diff_counts(before, after);
    if before.len().max(after.len()) > INLINE_CONTENT_LIMIT {
        return serde_json::json!({
            "additions": additions,
            "deletions": deletions,
            "truncated": true,
        });
    }
    serde_json::json!({
        "before": before,
        "after": after,
        "additions": additions,
        "deletions": deletions,
    })
}
/// Snapshots `path` into the context's backup store before it is modified.
///
/// Returns `Ok(None)` without touching the store when no metadata can be read
/// for `path`. Metadata is read with `symlink_metadata`, which does not follow
/// symlinks, so a dangling symlink still reaches the store. Otherwise the
/// store's `snapshot_with_op` result is returned.
///
/// # Errors
///
/// Propagates any error from the backup store.
pub fn auto_backup(
    ctx: &AppContext,
    session: &str,
    path: &Path,
    description: &str,
    op_id: Option<&str>,
) -> Result<Option<String>, AftError> {
    let has_entry = std::fs::symlink_metadata(path).is_ok();
    if !has_entry {
        return Ok(None);
    }
    let mut store = ctx.backup().lock();
    let snapshot_id = store.snapshot_with_op(session, path, description, op_id)?;
    Ok(snapshot_id)
}
/// Summary of the lines a formatter changed, produced by `compute_reformatted_excerpt`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ReformattedExcerpt {
    /// Post-format lines of the changed regions joined with `\n`; empty when
    /// `extensive` is set.
    pub text: String,
    /// `true` when the excerpt was too large to include inline.
    pub extensive: bool,
}
/// Maximum number of excerpt lines (gap markers included) before a reformat is
/// reported as extensive instead of being inlined.
const REFORMATTED_EXCERPT_MAX_LINES: usize = 60;
/// Maximum excerpt size in bytes, counting one separator byte per line, before
/// a reformat is reported as extensive instead of being inlined.
const REFORMATTED_EXCERPT_MAX_BYTES: usize = 4096;

/// Summarises what a formatter changed between `pre_format` and `post_format`.
///
/// Returns `None` when the texts are identical. Otherwise the excerpt contains
/// the post-format lines of every changed region plus up to two lines of
/// unchanged context (as grouped by `similar`'s `grouped_ops(2)`). Non-adjacent
/// regions are separated by a `…` line. If the excerpt would exceed
/// [`REFORMATTED_EXCERPT_MAX_LINES`] lines or [`REFORMATTED_EXCERPT_MAX_BYTES`]
/// bytes, an empty excerpt with `extensive: true` is returned instead.
pub fn compute_reformatted_excerpt(
    pre_format: &str,
    post_format: &str,
) -> Option<ReformattedExcerpt> {
    use similar::DiffTag;

    if pre_format == post_format {
        return None;
    }
    const GAP_MARKER: &str = "…";

    let diff = similar::TextDiff::from_lines(pre_format, post_format);
    // Take the post-format lines from the diff's own tokenization so that the
    // indices in `op.new_range()` always refer to the right line.
    // `str::lines()` does not split on a lone `\r`, while similar's line
    // tokenizer does, which would shift every later index for such files.
    let post_lines: Vec<&str> = diff
        .new_slices()
        .iter()
        .copied()
        .map(|line| {
            line.strip_suffix("\r\n")
                .or_else(|| line.strip_suffix('\n'))
                .or_else(|| line.strip_suffix('\r'))
                .unwrap_or(line)
        })
        .collect();

    let mut excerpt: Vec<&str> = Vec::new();
    // Running size of `excerpt`, counting one separator byte per line.
    let mut excerpt_bytes = 0usize;
    let mut last_post_idx: Option<usize> = None;

    for group in diff.grouped_ops(2) {
        // Inclusive span of post-format lines touched by this group. Pure
        // deletions have no post-format side and contribute nothing.
        let mut span: Option<(usize, usize)> = None;
        for op in group {
            if op.tag() == DiffTag::Delete {
                continue;
            }
            let new_range = op.new_range();
            if new_range.is_empty() {
                continue;
            }
            let (start, end) = (new_range.start, new_range.end - 1);
            span = Some(match span {
                Some((s, e)) => (s.min(start), e.max(end)),
                None => (start, end),
            });
        }
        let Some((start, end)) = span else {
            continue;
        };

        if matches!(last_post_idx, Some(prev) if start > prev + 1) {
            excerpt.push(GAP_MARKER);
            excerpt_bytes += GAP_MARKER.len() + 1;
        }
        // `take`/`skip` clip the span to the available lines defensively.
        for &line in post_lines.iter().take(end + 1).skip(start) {
            excerpt.push(line);
            excerpt_bytes += line.len() + 1;
        }
        last_post_idx = Some(end);

        // Both counters only grow, so once a cap is exceeded the finished
        // excerpt would exceed it too; stop collecting early.
        if excerpt.len() > REFORMATTED_EXCERPT_MAX_LINES
            || excerpt_bytes > REFORMATTED_EXCERPT_MAX_BYTES
        {
            return Some(ReformattedExcerpt {
                text: String::new(),
                extensive: true,
            });
        }
    }

    Some(ReformattedExcerpt {
        text: excerpt.join("\n"),
        extensive: false,
    })
}
/// Outcome of `write_format_validate`: what happened after content was written to disk.
pub struct WriteResult {
    /// Syntax check of the file after writing and auto-formatting
    /// (`Some(true)` = no error nodes); `None` when validity is unknown or the
    /// check failed.
    pub syntax_valid: Option<bool>,
    /// First value returned by `format::auto_format` — presumably whether the
    /// formatter ran and changed the file; confirm in `format`.
    pub formatted: bool,
    /// Second value returned by `format::auto_format` — presumably why
    /// formatting was skipped, if it was.
    pub format_skipped_reason: Option<String>,
    /// `true` when the effective `validate` mode (request param, else config) was `"full"`.
    pub validate_requested: bool,
    /// Errors from `format::validate_full`; empty when validation was not requested.
    pub validation_errors: Vec<format::ValidationError>,
    /// Reason returned by `format::validate_full` for skipping validation, if any.
    pub validate_skipped_reason: Option<String>,
    /// `true` when a previously valid file became syntactically invalid and the
    /// original content was written back.
    pub rolled_back: bool,
    /// LSP post-edit outcome; `write_format_validate` always leaves this `None`
    /// (presumably filled in later by the caller).
    pub lsp_outcome: Option<crate::lsp::manager::PostEditWaitOutcome>,
    /// Lines changed by the formatter relative to the requested content; `None`
    /// when nothing changed or the write was rolled back.
    pub reformatted_excerpt: Option<ReformattedExcerpt>,
}
/// Renders validation errors as a single `"line N: message"` list separated by `"; "`.
///
/// Returns an empty string when `errors` is empty.
pub fn format_validation_errors(errors: &[format::ValidationError]) -> String {
    let mut joined = String::new();
    for (position, error) in errors.iter().enumerate() {
        if position > 0 {
            joined.push_str("; ");
        }
        joined.push_str(&format!("line {}: {}", error.line, error.message));
    }
    joined
}
impl WriteResult {
    /// Writes the rollback flag and any LSP results into the JSON `result`.
    ///
    /// Always sets `rolled_back`. When an LSP outcome is present, it also sets:
    /// - `lsp_diagnostics`: one object per diagnostic.
    /// - `lsp_complete`.
    /// - `lsp_pending_servers` / `lsp_exited_servers`: server ids, only when
    ///   those lists are non-empty.
    pub fn append_lsp_diagnostics_to(&self, result: &mut serde_json::Value) {
        result["rolled_back"] = serde_json::json!(self.rolled_back);

        if let Some(outcome) = &self.lsp_outcome {
            let diagnostics: Vec<serde_json::Value> = outcome
                .diagnostics
                .iter()
                .map(|diag| {
                    serde_json::json!({
                        "file": diag.file.display().to_string(),
                        "line": diag.line,
                        "column": diag.column,
                        "end_line": diag.end_line,
                        "end_column": diag.end_column,
                        "severity": diag.severity.as_str(),
                        "message": diag.message,
                        "code": diag.code,
                        "source": diag.source,
                    })
                })
                .collect();
            result["lsp_diagnostics"] = serde_json::json!(diagnostics);
            result["lsp_complete"] = serde_json::Value::Bool(outcome.complete());

            if !outcome.pending_servers.is_empty() {
                let pending: Vec<String> = outcome
                    .pending_servers
                    .iter()
                    .map(|server| server.kind.id_str().to_string())
                    .collect();
                result["lsp_pending_servers"] = serde_json::json!(pending);
            }
            if !outcome.exited_servers.is_empty() {
                let exited: Vec<String> = outcome
                    .exited_servers
                    .iter()
                    .map(|server| server.kind.id_str().to_string())
                    .collect();
                result["lsp_exited_servers"] = serde_json::json!(exited);
            }
        }
    }

    /// Adds a `reformatted` entry to `result` when a reformat excerpt exists.
    ///
    /// The entry is `{ "extensive": true }` for oversized excerpts and
    /// `{ "text": … }` otherwise. Nothing is written when there is no excerpt.
    pub fn append_reformatted_excerpt_to(&self, result: &mut serde_json::Value) {
        let Some(excerpt) = self.reformatted_excerpt.as_ref() else {
            return;
        };
        result["reformatted"] = if excerpt.extensive {
            serde_json::json!({ "extensive": true })
        } else {
            serde_json::json!({ "text": excerpt.text })
        };
    }
}
/// Writes `content` to `path`, auto-formats it, checks syntax, and optionally runs full validation.
///
/// If the file existed, was readable as UTF-8, and was syntactically valid
/// before the write, but is invalid afterwards, its original content is written
/// back and `rolled_back` is set. Full validation runs when the effective
/// `validate` mode is `"full"`: the `validate` request param wins, then
/// `config.validate_on_edit`, then the default `"off"`.
///
/// # Errors
///
/// Returns `AftError::InvalidRequest` if writing the new content or the
/// rollback fails.
pub fn write_format_validate(
    path: &Path,
    content: &str,
    config: &Config,
    params: &serde_json::Value,
) -> Result<WriteResult, AftError> {
    // Snapshot the existing file so a write that breaks valid syntax can be undone.
    let previous_content: Option<String> = path
        .exists()
        .then(|| std::fs::read_to_string(path).ok())
        .flatten();
    // Syntax state before the write; check failures count as "unknown".
    let syntax_before = previous_content
        .as_ref()
        .and_then(|_| validate_syntax(path).ok().flatten());

    if let Err(io_err) = std::fs::write(path, content) {
        return Err(AftError::InvalidRequest {
            message: format!("failed to write file: {}", io_err),
        });
    }

    let (formatted, format_skipped_reason) = format::auto_format(path, config);
    let syntax_valid = validate_syntax(path).ok().flatten();

    // Only undo the write when it turned a valid file into an invalid one.
    let regressed = syntax_before == Some(true) && syntax_valid == Some(false);
    let rolled_back = match previous_content.as_deref() {
        Some(original) if regressed => {
            std::fs::write(path, original).map_err(|e| AftError::InvalidRequest {
                message: format!("failed to roll back invalid edit: {}", e),
            })?;
            true
        }
        _ => false,
    };

    let validate_mode = params
        .get("validate")
        .and_then(|v| v.as_str())
        .or_else(|| config.validate_on_edit.as_deref())
        .unwrap_or("off");
    let validate_requested = validate_mode == "full";
    let (validation_errors, validate_skipped_reason) = if validate_requested {
        format::validate_full(path, config)
    } else {
        (Vec::new(), None)
    };

    // Compare what the caller asked to write with what the formatter left on disk.
    let reformatted_excerpt = if rolled_back {
        None
    } else {
        match std::fs::read_to_string(path) {
            Ok(on_disk) => compute_reformatted_excerpt(content, &on_disk),
            Err(_) => None,
        }
    };

    Ok(WriteResult {
        syntax_valid,
        formatted,
        format_skipped_reason,
        validate_requested,
        validation_errors,
        validate_skipped_reason,
        rolled_back,
        lsp_outcome: None,
        reformatted_excerpt,
    })
}
#[cfg(test)]
mod tests {
    // Unit tests for the offset, byte-range, diff and excerpt helpers above.
    use super::*;

    /// Asserts that `result` is an `InvalidRequest` whose message contains `fragment`.
    fn expect_invalid_request(result: Result<String, AftError>, fragment: &str) {
        match result {
            Err(AftError::InvalidRequest { message, .. }) => {
                assert!(message.contains(fragment), "unexpected message: {message}")
            }
            other => panic!("expected InvalidRequest containing {fragment:?}, got {other:?}"),
        }
    }

    #[test]
    fn line_col_to_byte_empty_string() {
        assert_eq!(line_col_to_byte("", 0, 0), 0);
    }

    #[test]
    fn line_col_to_byte_single_line() {
        let source = "hello";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 3), 3);
        // A column equal to the line length addresses the end of the line.
        assert_eq!(line_col_to_byte(source, 0, 5), 5);
    }

    #[test]
    fn line_col_to_byte_multi_line() {
        let source = "abc\ndef\nghi\n";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 2), 2);
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        assert_eq!(line_col_to_byte(source, 1, 3), 7);
        assert_eq!(line_col_to_byte(source, 2, 0), 8);
        assert_eq!(line_col_to_byte(source, 2, 2), 10);
    }

    #[test]
    fn line_col_to_byte_last_line_no_trailing_newline() {
        let source = "abc\ndef";
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        // End of the final, unterminated line.
        assert_eq!(line_col_to_byte(source, 1, 3), 7);
    }

    #[test]
    fn line_col_to_byte_multi_byte_utf8() {
        let source = "café\nbar";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        // Columns are byte offsets: "café" is 5 bytes long.
        assert_eq!(line_col_to_byte(source, 0, 5), 5);
        assert_eq!(line_col_to_byte(source, 1, 0), 6);
        assert_eq!(line_col_to_byte(source, 1, 2), 8);
    }

    #[test]
    fn line_col_to_byte_beyond_end() {
        let source = "abc";
        assert_eq!(line_col_to_byte(source, 5, 0), source.len());
    }

    #[test]
    fn line_col_to_byte_col_clamped_to_line_length() {
        let source = "ab\ncd";
        assert_eq!(line_col_to_byte(source, 0, 10), 2);
    }

    #[test]
    fn line_col_to_byte_crlf() {
        let source = "abc\r\ndef\r\nghi\r\n";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 10), 3);
        assert_eq!(line_col_to_byte(source, 1, 0), 5);
        assert_eq!(line_col_to_byte(source, 1, 3), 8);
        assert_eq!(line_col_to_byte(source, 2, 0), 10);
    }

    #[test]
    fn line_col_to_byte_bare_carriage_return() {
        // A lone `\r` terminates a line just like `\n` does.
        let source = "ab\rcd";
        assert_eq!(line_col_to_byte(source, 1, 0), 3);
        assert_eq!(line_col_to_byte(source, 1, 9), 5);
    }

    #[test]
    fn line_col_to_byte_line_after_trailing_newline() {
        let source = "abc\n";
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        assert_eq!(line_col_to_byte(source, 1, 7), 4);
        assert_eq!(line_col_to_byte(source, 2, 0), 4);
    }

    #[test]
    fn replace_byte_range_basic() {
        let source = "hello world";
        let result = replace_byte_range(source, 6, 11, "rust").unwrap();
        assert_eq!(result, "hello rust");
    }

    #[test]
    fn replace_byte_range_delete() {
        let source = "hello world";
        let result = replace_byte_range(source, 5, 11, "").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn replace_byte_range_insert_at_same_position() {
        let source = "helloworld";
        let result = replace_byte_range(source, 5, 5, " ").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn replace_byte_range_replace_entire_string() {
        let source = "old content";
        let result = replace_byte_range(source, 0, source.len(), "new content").unwrap();
        assert_eq!(result, "new content");
    }

    #[test]
    fn replace_byte_range_multi_byte_boundaries() {
        // "é" occupies bytes 3..5 of "café".
        let result = replace_byte_range("café", 3, 5, "e").unwrap();
        assert_eq!(result, "cafe");
    }

    #[test]
    fn replace_byte_range_rejects_inverted_range() {
        expect_invalid_request(replace_byte_range("hello", 3, 1, ""), "start must be <= end");
    }

    #[test]
    fn replace_byte_range_rejects_end_past_source() {
        expect_invalid_request(
            replace_byte_range("hello", 0, 6, ""),
            "end exceeds source length 5",
        );
    }

    #[test]
    fn replace_byte_range_rejects_start_inside_char() {
        expect_invalid_request(
            replace_byte_range("café", 4, 5, "e"),
            "start is not a char boundary",
        );
    }

    #[test]
    fn replace_byte_range_rejects_end_inside_char() {
        expect_invalid_request(
            replace_byte_range("café", 0, 4, ""),
            "end is not a char boundary",
        );
    }

    #[test]
    fn wants_diff_is_implied_by_include_diff_content() {
        assert!(!wants_diff(&serde_json::json!({})));
        assert!(wants_diff(&serde_json::json!({ "include_diff": true })));
        assert!(wants_diff(&serde_json::json!({ "include_diff_content": true })));
        // Only real JSON booleans count.
        assert!(!wants_diff(&serde_json::json!({ "include_diff": "yes" })));
    }

    #[test]
    fn compute_diff_counts_counts_changed_lines() {
        let counts = compute_diff_counts("a\nb\nc\n", "a\nB\nc\nd\n");
        assert_eq!(counts["additions"], 2);
        assert_eq!(counts["deletions"], 1);
    }

    #[test]
    fn compute_diff_info_includes_content_for_small_inputs() {
        let info = compute_diff_info("a\n", "b\n");
        assert_eq!(info["before"], "a\n");
        assert_eq!(info["after"], "b\n");
        assert!(info.get("truncated").is_none());
    }

    #[test]
    fn compute_diff_info_omits_content_for_large_inputs() {
        let before = "x".repeat(600 * 1024);
        let info = compute_diff_info(&before, "y");
        assert_eq!(info["truncated"], true);
        assert!(info.get("before").is_none());
        assert!(info.get("after").is_none());
        assert_eq!(info["additions"], 1);
        assert_eq!(info["deletions"], 1);
    }

    #[test]
    fn build_unified_diff_keeps_header_when_unchanged() {
        let diff = build_unified_diff("a.rs", "x\n", "x\n");
        assert!(diff.starts_with("Index: a.rs\n"));
        assert!(diff.ends_with("--- a.rs\n+++ a.rs\n"));
    }

    #[test]
    fn build_unified_diff_contains_hunk_when_changed() {
        let diff = build_unified_diff("a.rs", "x\n", "y\n");
        assert!(diff.starts_with("Index: a.rs\n"));
        assert!(diff.contains("--- a.rs\n+++ a.rs\n"));
        assert!(diff.contains("-x\n"));
        assert!(diff.contains("+y\n"));
    }

    #[test]
    fn compute_reformatted_excerpt_self_suppresses_when_unchanged() {
        let s = "fn main() {\n let x = 1;\n}\n";
        assert!(compute_reformatted_excerpt(s, s).is_none());
    }

    #[test]
    fn compute_reformatted_excerpt_includes_post_format_text() {
        let before = "fn main( ){ let x=1; }";
        let after = "fn main() {\n let x = 1;\n}\n";
        let excerpt = compute_reformatted_excerpt(before, after).expect("should diff");
        assert!(!excerpt.extensive);
        assert!(excerpt.text.contains("fn main()"));
        assert!(excerpt.text.contains("let x = 1"));
    }

    #[test]
    fn compute_reformatted_excerpt_separates_distant_changes() {
        let before: String = (0..10).map(|i| format!("line{i}\n")).collect();
        let after = before
            .replace("line0\n", "LINE0\n")
            .replace("line9\n", "LINE9\n");
        let excerpt = compute_reformatted_excerpt(&before, &after).expect("should diff");
        assert!(!excerpt.extensive);
        let lines: Vec<&str> = excerpt.text.lines().collect();
        assert!(lines.contains(&"LINE0"));
        assert!(lines.contains(&"LINE9"));
        // The untouched middle is elided and replaced by a gap marker.
        assert!(lines.contains(&"…"));
        assert!(!lines.contains(&"line5"));
    }

    #[test]
    fn compute_reformatted_excerpt_extensive_when_over_line_cap() {
        let before: String = (0..80).map(|i| format!("line{i} ugly\n")).collect();
        let after: String = (0..80).map(|i| format!("line{i} neat\n")).collect();
        let excerpt = compute_reformatted_excerpt(&before, &after).expect("should diff");
        assert!(excerpt.extensive);
        assert!(excerpt.text.is_empty());
    }

    #[test]
    fn validate_syntax_str_accepts_reference_to_variable_named_raw() {
        let path = Path::new("lib.rs");
        let src = "fn handle_hash(x: &u32) -> u32 { *x }\n\
fn main() {\n let raw = 5u32;\n let _ = handle_hash(&raw);\n}\n";
        assert_eq!(validate_syntax_str(src, path), Some(true));
    }

    #[test]
    fn validate_syntax_str_accepts_raw_borrow_operators() {
        let path = Path::new("lib.rs");
        let const_borrow = "fn main() {\n let x = 5u32;\n let _p = &raw const x;\n}\n";
        let mut_borrow = "fn main() {\n let mut x = 5u32;\n let _p = &raw mut x;\n}\n";
        assert_eq!(validate_syntax_str(const_borrow, path), Some(true));
        assert_eq!(validate_syntax_str(mut_borrow, path), Some(true));
    }
}