#![cfg_attr(test, allow(clippy::items_after_test_module))]
use std::path::Path;
use crate::config::Config;
use crate::context::AppContext;
use crate::error::AftError;
use crate::format;
use crate::parser::{detect_language, grammar_for, FileParser};
/// Converts a zero-based `(line, col)` position into a byte offset within `source`.
///
/// Lines are terminated by `\n`, `\r\n`, or a lone `\r`. `col` is measured in bytes
/// from the start of the line and is clamped to the line's length (terminator
/// excluded), so an over-long column lands at the end of that line. When `line` is
/// past the last line, `source.len()` is returned.
///
/// The offset is not snapped to a UTF-8 character boundary: a `col` pointing into a
/// multi-byte character yields an offset inside that character.
pub fn line_col_to_byte(source: &str, line: u32, col: u32) -> usize {
    let bytes = source.as_bytes();
    let mut start = 0usize;
    let mut lines_to_skip = line as usize;
    loop {
        // End of the current line: first terminator byte, or end of input.
        let end = bytes[start..]
            .iter()
            .position(|&b| b == b'\n' || b == b'\r')
            .map_or(bytes.len(), |offset| start + offset);

        if lines_to_skip == 0 {
            return start + (end - start).min(col as usize);
        }
        if end == bytes.len() {
            // Requested line lies beyond the last line of input.
            return source.len();
        }

        // A `\r\n` pair is one terminator; a lone `\r` or `\n` is a single byte.
        let terminator_len = if bytes[end..].starts_with(b"\r\n") { 2 } else { 1 };
        start = end + terminator_len;
        lines_to_skip -= 1;
    }
}
/// Returns a copy of `source` with the byte range `[start, end)` replaced by `replacement`.
///
/// `start == end` inserts `replacement` at that offset; an empty `replacement`
/// deletes the range.
///
/// # Errors
/// Returns `AftError::InvalidRequest` (checked in this order) when `start > end`,
/// when `end` exceeds `source.len()`, or when `start` or `end` is not on a UTF-8
/// character boundary.
pub fn replace_byte_range(
    source: &str,
    start: usize,
    end: usize,
    replacement: &str,
) -> Result<String, AftError> {
    // Every rejection shares the same "invalid byte range [..)" prefix.
    let reject = |reason: &str| AftError::InvalidRequest {
        message: format!("invalid byte range [{}..{}): {}", start, end, reason),
    };

    if start > end {
        return Err(reject("start must be <= end"));
    }
    if end > source.len() {
        return Err(reject(&format!(
            "end exceeds source length {}",
            source.len()
        )));
    }
    if !source.is_char_boundary(start) {
        return Err(reject("start is not a char boundary"));
    }
    if !source.is_char_boundary(end) {
        return Err(reject("end is not a char boundary"));
    }

    // All bounds are validated above, so these slices cannot panic.
    Ok([&source[..start], replacement, &source[end..]].concat())
}
/// Parses the file at `path` and reports whether its syntax tree is error-free.
///
/// Returns `Ok(Some(true))` for a clean parse and `Ok(Some(false))` when the tree
/// contains error nodes. An `AftError::InvalidRequest` from the parser is mapped to
/// `Ok(None)` ("no verdict"), presumably meaning the file's language is unsupported
/// — NOTE(review): confirm against `FileParser::parse`.
///
/// # Errors
/// Any other parser error is propagated unchanged.
pub fn validate_syntax(path: &Path) -> Result<Option<bool>, AftError> {
    let mut parser = FileParser::new();
    parser
        .parse(path)
        .map(|(tree, _lang)| Some(!tree.root_node().has_error()))
        .or_else(|err| match err {
            AftError::InvalidRequest { .. } => Ok(None),
            other => Err(other),
        })
}
/// Parses in-memory `content` using the language detected from `path` and reports
/// whether the resulting syntax tree is error-free.
///
/// Returns `None` when no language is detected for `path`, when the grammar cannot be
/// loaded into the parser, or when parsing yields no tree. The file at `path` itself
/// is never read.
pub fn validate_syntax_str(content: &str, path: &Path) -> Option<bool> {
    let language = detect_language(path)?;
    let grammar = grammar_for(language);

    let mut ts_parser = tree_sitter::Parser::new();
    ts_parser.set_language(&grammar).ok()?;

    ts_parser
        .parse(content.as_bytes(), None)
        .map(|tree| !tree.root_node().has_error())
}
pub fn wants_diff(params: &serde_json::Value) -> bool {
params
.get("include_diff")
.and_then(|v| v.as_bool())
.unwrap_or(false)
}
/// Summarizes a line-based diff between `before` and `after` as JSON.
///
/// Always reports `"additions"` and `"deletions"` (counts of inserted and deleted
/// lines). When both texts are at most 512 KiB, the full `"before"` and `"after"`
/// texts are included as well; otherwise they are omitted and `"truncated": true`
/// is set instead.
pub fn compute_diff_info(before: &str, after: &str) -> serde_json::Value {
    use similar::ChangeTag;

    const SIZE_LIMIT: usize = 512 * 1024;

    let diff = similar::TextDiff::from_lines(before, after);
    let (additions, deletions) =
        diff.iter_all_changes()
            .fold((0usize, 0usize), |(adds, dels), change| match change.tag() {
                ChangeTag::Insert => (adds + 1, dels),
                ChangeTag::Delete => (adds, dels + 1),
                ChangeTag::Equal => (adds, dels),
            });

    let too_large = before.len() > SIZE_LIMIT || after.len() > SIZE_LIMIT;
    if too_large {
        // Keep responses bounded: counts only, no full texts.
        serde_json::json!({
            "additions": additions,
            "deletions": deletions,
            "truncated": true,
        })
    } else {
        serde_json::json!({
            "before": before,
            "after": after,
            "additions": additions,
            "deletions": deletions,
        })
    }
}
/// Snapshots the file at `path` into the context's backup store.
///
/// Returns `Ok(None)` without touching the store when `path` does not exist.
/// Otherwise returns `Ok(Some(id))` with the id produced by the store's `snapshot`.
///
/// # Errors
/// Propagates any error returned by the store's `snapshot`.
pub fn auto_backup(
    ctx: &AppContext,
    session: &str,
    path: &Path,
    description: &str,
) -> Result<Option<String>, AftError> {
    if path.exists() {
        // The guard from `borrow_mut()` is a temporary released at the end of this statement.
        let backup_id = ctx
            .backup()
            .borrow_mut()
            .snapshot(session, path, description)?;
        Ok(Some(backup_id))
    } else {
        Ok(None)
    }
}
/// Outcome of writing a file via `write_format_validate`: syntax check,
/// auto-format, optional full validation, rollback status and (optionally) the
/// LSP post-edit diagnostics.
pub struct WriteResult {
/// Syntax check of the file after writing and formatting. `None` when no verdict
/// was available. Still `Some(false)` after a rollback, i.e. it describes the
/// attempted edit rather than the restored file.
pub syntax_valid: Option<bool>,
/// First value returned by `format::auto_format`; presumably whether formatting
/// was applied — confirm against `format::auto_format`.
pub formatted: bool,
/// Second value returned by `format::auto_format`; presumably why formatting was
/// skipped, if it was.
pub format_skipped_reason: Option<String>,
/// `true` when the effective validation mode (request `"validate"` param, else
/// `config.validate_on_edit`, else `"off"`) was `"full"`.
pub validate_requested: bool,
/// Errors reported by `format::validate_full`; empty when full validation was not
/// requested.
pub validation_errors: Vec<format::ValidationError>,
/// Skip reason reported by `format::validate_full`; `None` when full validation
/// was not requested.
pub validate_skipped_reason: Option<String>,
/// `true` when the edit turned a syntactically valid file invalid and the pre-edit
/// content was written back to disk.
pub rolled_back: bool,
/// LSP post-edit wait result. `write_format_validate` leaves this `None`;
/// presumably the caller fills it in after waiting on language servers.
pub lsp_outcome: Option<crate::lsp::manager::PostEditWaitOutcome>,
}
impl WriteResult {
    /// Writes the rollback flag and, if present, the LSP post-edit outcome into the
    /// JSON response `result`.
    ///
    /// Always sets `"rolled_back"`. When `lsp_outcome` is `Some`, it also sets
    /// `"lsp_diagnostics"` (one object per diagnostic) and `"lsp_complete"`. It
    /// additionally sets `"lsp_pending_servers"` / `"lsp_exited_servers"`, but only
    /// when the corresponding server list is non-empty.
    ///
    /// `result` is written by string key, so it must be a JSON object (or `null`,
    /// which serde_json converts to an object); serde_json panics for other kinds.
    pub fn append_lsp_diagnostics_to(&self, result: &mut serde_json::Value) {
        result["rolled_back"] = serde_json::json!(self.rolled_back);

        let outcome = match &self.lsp_outcome {
            Some(outcome) => outcome,
            None => return,
        };

        let diagnostics: Vec<serde_json::Value> = outcome
            .diagnostics
            .iter()
            .map(|diag| {
                serde_json::json!({
                    "file": diag.file.display().to_string(),
                    "line": diag.line,
                    "column": diag.column,
                    "end_line": diag.end_line,
                    "end_column": diag.end_column,
                    "severity": diag.severity.as_str(),
                    "message": diag.message,
                    "code": diag.code,
                    "source": diag.source,
                })
            })
            .collect();
        result["lsp_diagnostics"] = serde_json::Value::Array(diagnostics);
        result["lsp_complete"] = serde_json::Value::Bool(outcome.complete());

        if !outcome.pending_servers.is_empty() {
            let pending: Vec<serde_json::Value> = outcome
                .pending_servers
                .iter()
                .map(|key| serde_json::Value::String(key.kind.id_str().to_string()))
                .collect();
            result["lsp_pending_servers"] = serde_json::Value::Array(pending);
        }

        if !outcome.exited_servers.is_empty() {
            let exited: Vec<serde_json::Value> = outcome
                .exited_servers
                .iter()
                .map(|key| serde_json::Value::String(key.kind.id_str().to_string()))
                .collect();
            result["lsp_exited_servers"] = serde_json::Value::Array(exited);
        }
    }
}
/// Writes `content` to `path`, auto-formats it, checks its syntax and optionally
/// runs full validation.
///
/// Steps:
/// 1. If `path` exists and reads as UTF-8, its current content is kept and its
///    syntax validity recorded.
/// 2. `content` is written, then `format::auto_format` runs on the file.
/// 3. The file's syntax is re-checked. If it was valid before and is invalid now,
///    the original content is written back and `rolled_back` is set.
/// 4. `format::validate_full` runs when the effective validate mode is `"full"`.
///    The `"validate"` string in `params` takes precedence over
///    `config.validate_on_edit`, and the default is `"off"`.
///
/// Errors from syntax checking are treated as "no verdict" (`None`), not failures.
/// The returned `lsp_outcome` is always `None`.
///
/// # Errors
/// Returns `AftError::InvalidRequest` if writing `content` fails or if the rollback
/// write fails.
pub fn write_format_validate(
    path: &Path,
    content: &str,
    config: &Config,
    params: &serde_json::Value,
) -> Result<WriteResult, AftError> {
    let pre_write_content = path
        .exists()
        .then(|| std::fs::read_to_string(path).ok())
        .flatten();
    // Pre-edit validity only matters when there is content to roll back to.
    let was_syntax_valid = pre_write_content
        .as_ref()
        .and_then(|_| validate_syntax(path).ok().flatten());

    std::fs::write(path, content).map_err(|e| AftError::InvalidRequest {
        message: format!("failed to write file: {}", e),
    })?;

    let (formatted, format_skipped_reason) = format::auto_format(path, config);
    let syntax_valid = validate_syntax(path).ok().flatten();

    // Only roll back edits that broke a previously valid file.
    let rolled_back = match (pre_write_content.as_deref(), was_syntax_valid, syntax_valid) {
        (Some(original), Some(true), Some(false)) => {
            std::fs::write(path, original).map_err(|e| AftError::InvalidRequest {
                message: format!("failed to roll back invalid edit: {}", e),
            })?;
            true
        }
        _ => false,
    };

    // NOTE(review): after a rollback the file holds the pre-edit content again, so
    // full validation below checks the original file rather than `content` —
    // confirm this is intended.
    let validate_mode = params
        .get("validate")
        .and_then(|v| v.as_str())
        .or(config.validate_on_edit.as_deref())
        .unwrap_or("off");
    let validate_requested = validate_mode == "full";
    let (validation_errors, validate_skipped_reason) = if validate_requested {
        format::validate_full(path, config)
    } else {
        (Vec::new(), None)
    };

    Ok(WriteResult {
        syntax_valid,
        formatted,
        format_skipped_reason,
        validate_requested,
        validation_errors,
        validate_skipped_reason,
        rolled_back,
        lsp_outcome: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `result` is an `InvalidRequest` whose message contains `needle`.
    fn assert_invalid_request(result: Result<String, AftError>, needle: &str) {
        match result {
            Err(AftError::InvalidRequest { message }) => assert!(
                message.contains(needle),
                "unexpected error message: {}",
                message
            ),
            other => panic!("expected InvalidRequest error, got {:?}", other),
        }
    }

    #[test]
    fn line_col_to_byte_empty_string() {
        assert_eq!(line_col_to_byte("", 0, 0), 0);
    }

    #[test]
    fn line_col_to_byte_single_line() {
        let source = "hello";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 3), 3);
        assert_eq!(line_col_to_byte(source, 0, 5), 5);
    }

    #[test]
    fn line_col_to_byte_multi_line() {
        let source = "abc\ndef\nghi\n";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 2), 2);
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        assert_eq!(line_col_to_byte(source, 1, 3), 7);
        assert_eq!(line_col_to_byte(source, 2, 0), 8);
        assert_eq!(line_col_to_byte(source, 2, 2), 10);
    }

    #[test]
    fn line_col_to_byte_last_line_no_trailing_newline() {
        let source = "abc\ndef";
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        assert_eq!(line_col_to_byte(source, 1, 3), 7);
    }

    #[test]
    fn line_col_to_byte_multi_byte_utf8() {
        // "café" is 5 bytes: 'é' is encoded as two bytes.
        let source = "café\nbar";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 5), 5);
        assert_eq!(line_col_to_byte(source, 1, 0), 6);
        assert_eq!(line_col_to_byte(source, 1, 2), 8);
    }

    #[test]
    fn line_col_to_byte_beyond_end() {
        let source = "abc";
        assert_eq!(line_col_to_byte(source, 5, 0), source.len());
    }

    #[test]
    fn line_col_to_byte_col_clamped_to_line_length() {
        let source = "ab\ncd";
        assert_eq!(line_col_to_byte(source, 0, 10), 2);
    }

    #[test]
    fn line_col_to_byte_crlf() {
        let source = "abc\r\ndef\r\nghi\r\n";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 10), 3);
        assert_eq!(line_col_to_byte(source, 1, 0), 5);
        assert_eq!(line_col_to_byte(source, 1, 3), 8);
        assert_eq!(line_col_to_byte(source, 2, 0), 10);
    }

    #[test]
    fn line_col_to_byte_lone_cr() {
        // A bare '\r' (classic Mac line ending) also terminates a line.
        let source = "ab\rcd";
        assert_eq!(line_col_to_byte(source, 0, 10), 2);
        assert_eq!(line_col_to_byte(source, 1, 0), 3);
        assert_eq!(line_col_to_byte(source, 1, 1), 4);
        assert_eq!(line_col_to_byte("a\r", 1, 0), 2);
    }

    #[test]
    fn line_col_to_byte_empty_line() {
        let source = "a\n\nb";
        assert_eq!(line_col_to_byte(source, 1, 5), 2);
        assert_eq!(line_col_to_byte(source, 2, 0), 3);
    }

    #[test]
    fn replace_byte_range_basic() {
        let source = "hello world";
        let result = replace_byte_range(source, 6, 11, "rust").unwrap();
        assert_eq!(result, "hello rust");
    }

    #[test]
    fn replace_byte_range_delete() {
        let source = "hello world";
        let result = replace_byte_range(source, 5, 11, "").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn replace_byte_range_insert_at_same_position() {
        let source = "helloworld";
        let result = replace_byte_range(source, 5, 5, " ").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn replace_byte_range_replace_entire_string() {
        let source = "old content";
        let result = replace_byte_range(source, 0, source.len(), "new content").unwrap();
        assert_eq!(result, "new content");
    }

    #[test]
    fn replace_byte_range_multi_byte_char() {
        // 'é' occupies bytes 3..5 of "café".
        let result = replace_byte_range("café", 3, 5, "e").unwrap();
        assert_eq!(result, "cafe");
    }

    #[test]
    fn replace_byte_range_rejects_start_after_end() {
        assert_invalid_request(
            replace_byte_range("abc", 2, 1, "x"),
            "start must be <= end",
        );
    }

    #[test]
    fn replace_byte_range_rejects_end_past_source() {
        assert_invalid_request(
            replace_byte_range("abc", 0, 4, ""),
            "end exceeds source length 3",
        );
    }

    #[test]
    fn replace_byte_range_rejects_start_inside_char() {
        assert_invalid_request(
            replace_byte_range("é", 1, 2, ""),
            "start is not a char boundary",
        );
    }

    #[test]
    fn replace_byte_range_rejects_end_inside_char() {
        assert_invalid_request(
            replace_byte_range("é", 0, 1, ""),
            "end is not a char boundary",
        );
    }

    #[test]
    fn wants_diff_requires_boolean_true() {
        assert!(wants_diff(&serde_json::json!({ "include_diff": true })));
        assert!(!wants_diff(&serde_json::json!({ "include_diff": false })));
        assert!(!wants_diff(&serde_json::json!({ "include_diff": "true" })));
        assert!(!wants_diff(&serde_json::json!({})));
        assert!(!wants_diff(&serde_json::Value::Null));
    }

    #[test]
    fn compute_diff_info_counts_lines_and_includes_texts() {
        let before = "a\nb\n";
        let after = "a\nc\n";
        let info = compute_diff_info(before, after);
        assert_eq!(info["additions"].as_u64(), Some(1));
        assert_eq!(info["deletions"].as_u64(), Some(1));
        assert_eq!(info["before"].as_str(), Some(before));
        assert_eq!(info["after"].as_str(), Some(after));
        assert!(info.get("truncated").is_none());
    }

    #[test]
    fn compute_diff_info_truncates_large_inputs() {
        let big = "x".repeat(512 * 1024 + 1);
        let info = compute_diff_info(&big, &big);
        assert_eq!(info["additions"].as_u64(), Some(0));
        assert_eq!(info["deletions"].as_u64(), Some(0));
        assert_eq!(info["truncated"].as_bool(), Some(true));
        assert!(info.get("before").is_none());
        assert!(info.get("after").is_none());
    }
}