1#![cfg_attr(test, allow(clippy::items_after_test_module))]
7
8use std::path::Path;
9
10use crate::config::Config;
11use crate::context::AppContext;
12use crate::error::AftError;
13use crate::format;
14use crate::parser::{detect_language, grammar_for, FileParser};
15
/// Converts a zero-based `(line, col)` position into a byte offset into `source`.
///
/// Line breaks are `\n`, `\r\n` (counted as one break) or a lone `\r`.
/// `col` counts bytes from the start of the line and is clamped to the line's
/// length, so the result never points into the line terminator. When `line` is
/// past the last line, `source.len()` is returned.
pub fn line_col_to_byte(source: &str, line: u32, col: u32) -> usize {
    let bytes = source.as_bytes();

    // Index of the first terminator byte at or after `from`, or the end of input.
    let line_end_from = |from: usize| -> usize {
        bytes[from..]
            .iter()
            .position(|&b| b == b'\n' || b == b'\r')
            .map_or(bytes.len(), |offset| from + offset)
    };

    let mut start = 0usize;
    for _ in 0..line {
        let end = line_end_from(start);
        if end >= bytes.len() {
            // Input ended before the requested line was reached.
            return source.len();
        }
        // Step over the terminator; "\r\n" is two bytes but a single break.
        let terminator_len = if bytes[end] == b'\r' && bytes.get(end + 1) == Some(&b'\n') {
            2
        } else {
            1
        };
        start = end + terminator_len;
    }

    let end = line_end_from(start);
    start + (col as usize).min(end - start)
}
54
55pub fn replace_byte_range(
59 source: &str,
60 start: usize,
61 end: usize,
62 replacement: &str,
63) -> Result<String, AftError> {
64 if start > end {
65 return Err(AftError::InvalidRequest {
66 message: format!(
67 "invalid byte range [{}..{}): start must be <= end",
68 start, end
69 ),
70 });
71 }
72 if end > source.len() {
73 return Err(AftError::InvalidRequest {
74 message: format!(
75 "invalid byte range [{}..{}): end exceeds source length {}",
76 start,
77 end,
78 source.len()
79 ),
80 });
81 }
82 if !source.is_char_boundary(start) {
83 return Err(AftError::InvalidRequest {
84 message: format!(
85 "invalid byte range [{}..{}): start is not a char boundary",
86 start, end
87 ),
88 });
89 }
90 if !source.is_char_boundary(end) {
91 return Err(AftError::InvalidRequest {
92 message: format!(
93 "invalid byte range [{}..{}): end is not a char boundary",
94 start, end
95 ),
96 });
97 }
98
99 let mut result = String::with_capacity(
100 source.len().saturating_sub(end.saturating_sub(start)) + replacement.len(),
101 );
102 result.push_str(&source[..start]);
103 result.push_str(replacement);
104 result.push_str(&source[end..]);
105 Ok(result)
106}
107
108pub fn validate_syntax(path: &Path) -> Result<Option<bool>, AftError> {
113 let mut parser = FileParser::new();
114 match parser.parse(path) {
115 Ok((tree, _lang)) => Ok(Some(!tree.root_node().has_error())),
116 Err(AftError::InvalidRequest { .. }) => {
117 Ok(None)
119 }
120 Err(e) => Err(e),
121 }
122}
123
124pub fn validate_syntax_str(content: &str, path: &Path) -> Option<bool> {
130 let lang = detect_language(path)?;
131 let grammar = grammar_for(lang);
132 let mut parser = tree_sitter::Parser::new();
133 if parser.set_language(&grammar).is_err() {
134 return None;
135 }
136 let tree = parser.parse(content.as_bytes(), None)?;
137 Some(!tree.root_node().has_error())
138}
139
140pub fn wants_diff(params: &serde_json::Value) -> bool {
142 params
143 .get("include_diff")
144 .and_then(|v| v.as_bool())
145 .unwrap_or(false)
146}
147
148pub fn compute_diff_info(before: &str, after: &str) -> serde_json::Value {
152 use similar::ChangeTag;
153
154 let diff = similar::TextDiff::from_lines(before, after);
155 let mut additions = 0usize;
156 let mut deletions = 0usize;
157 for change in diff.iter_all_changes() {
158 match change.tag() {
159 ChangeTag::Insert => additions += 1,
160 ChangeTag::Delete => deletions += 1,
161 ChangeTag::Equal => {}
162 }
163 }
164
165 let size_limit = 512 * 1024; if before.len() > size_limit || after.len() > size_limit {
168 serde_json::json!({
169 "additions": additions,
170 "deletions": deletions,
171 "truncated": true,
172 })
173 } else {
174 serde_json::json!({
175 "before": before,
176 "after": after,
177 "additions": additions,
178 "deletions": deletions,
179 })
180 }
181}
182pub fn auto_backup(
194 ctx: &AppContext,
195 session: &str,
196 path: &Path,
197 description: &str,
198) -> Result<Option<String>, AftError> {
199 if !path.exists() {
200 return Ok(None);
201 }
202 let backup_id = {
203 let mut store = ctx.backup().borrow_mut();
204 store.snapshot(session, path, description)?
205 }; Ok(Some(backup_id))
207}
208
/// Outcome of [`write_format_validate`]: what happened to a file after new
/// content was written (auto-format, syntax check, optional rollback, optional
/// full validation, and a slot for LSP diagnostics).
pub struct WriteResult {
    /// Syntax status of the file after writing and auto-formatting, from
    /// `validate_syntax`. It is `None` when the check was unavailable or errored.
    /// This is not recomputed after a rollback.
    pub syntax_valid: Option<bool>,
    /// First value returned by `format::auto_format` (presumably `true` when a
    /// formatter rewrote the file; confirm).
    pub formatted: bool,
    /// Second value returned by `format::auto_format` (presumably the reason
    /// formatting was skipped; confirm).
    pub format_skipped_reason: Option<String>,
    /// `true` when the effective validate mode is `"full"`. The mode comes from
    /// the request param `validate`, else `config.validate_on_edit`, else `"off"`.
    pub validate_requested: bool,
    /// Errors returned by `format::validate_full`; empty when validation was not requested.
    pub validation_errors: Vec<format::ValidationError>,
    /// Skip reason returned by `format::validate_full`; `None` when validation was not requested.
    pub validate_skipped_reason: Option<String>,
    /// `true` when the file was syntax-valid before the write and invalid after
    /// it, and the pre-write content was written back.
    pub rolled_back: bool,
    /// LSP post-edit diagnostics. `write_format_validate` always sets this to
    /// `None` (presumably a caller fills it in afterwards; confirm).
    pub lsp_outcome: Option<crate::lsp::manager::PostEditWaitOutcome>,
}
244
245impl WriteResult {
246 pub fn append_lsp_diagnostics_to(&self, result: &mut serde_json::Value) {
261 result["rolled_back"] = serde_json::json!(self.rolled_back);
262
263 let Some(outcome) = self.lsp_outcome.as_ref() else {
264 return;
265 };
266
267 result["lsp_diagnostics"] = serde_json::json!(outcome
268 .diagnostics
269 .iter()
270 .map(|d| {
271 serde_json::json!({
272 "file": d.file.display().to_string(),
273 "line": d.line,
274 "column": d.column,
275 "end_line": d.end_line,
276 "end_column": d.end_column,
277 "severity": d.severity.as_str(),
278 "message": d.message,
279 "code": d.code,
280 "source": d.source,
281 })
282 })
283 .collect::<Vec<_>>());
284
285 result["lsp_complete"] = serde_json::Value::Bool(outcome.complete());
286
287 if !outcome.pending_servers.is_empty() {
288 result["lsp_pending_servers"] = serde_json::json!(outcome
289 .pending_servers
290 .iter()
291 .map(|key| key.kind.id_str().to_string())
292 .collect::<Vec<_>>());
293 }
294 if !outcome.exited_servers.is_empty() {
295 result["lsp_exited_servers"] = serde_json::json!(outcome
296 .exited_servers
297 .iter()
298 .map(|key| key.kind.id_str().to_string())
299 .collect::<Vec<_>>());
300 }
301 }
302}
303
304pub fn write_format_validate(
317 path: &Path,
318 content: &str,
319 config: &Config,
320 params: &serde_json::Value,
321) -> Result<WriteResult, AftError> {
322 let pre_write_content = if path.exists() {
323 std::fs::read_to_string(path).ok()
324 } else {
325 None
326 };
327 let was_syntax_valid = if pre_write_content.is_some() {
331 match validate_syntax(path) {
332 Ok(valid) => valid,
333 Err(_) => None,
334 }
335 } else {
336 None
337 };
338
339 std::fs::write(path, content).map_err(|e| AftError::InvalidRequest {
341 message: format!("failed to write file: {}", e),
342 })?;
343
344 let (formatted, format_skipped_reason) = format::auto_format(path, config);
346
347 let syntax_valid = match validate_syntax(path) {
349 Ok(sv) => sv,
350 Err(_) => None,
351 };
352 let rolled_back = if was_syntax_valid == Some(true) && syntax_valid == Some(false) {
353 if let Some(original) = pre_write_content.as_ref() {
354 std::fs::write(path, original).map_err(|e| AftError::InvalidRequest {
355 message: format!("failed to roll back invalid edit: {}", e),
356 })?;
357 true
358 } else {
359 false
360 }
361 } else {
362 false
363 };
364
365 let param_validate = params.get("validate").and_then(|v| v.as_str());
367 let config_validate = config.validate_on_edit.as_deref();
368 let validate_mode = param_validate.or(config_validate).unwrap_or("off");
370 let validate_requested = validate_mode == "full";
371 let (validation_errors, validate_skipped_reason) = if validate_requested {
372 format::validate_full(path, config)
373 } else {
374 (Vec::new(), None)
375 };
376
377 Ok(WriteResult {
378 syntax_valid,
379 formatted,
380 format_skipped_reason,
381 validate_requested,
382 validation_errors,
383 validate_skipped_reason,
384 rolled_back,
385 lsp_outcome: None,
386 })
387}
388
#[cfg(test)]
mod tests {
    use super::*;

    // ---- line_col_to_byte -------------------------------------------------

    #[test]
    fn line_col_to_byte_empty_string() {
        assert_eq!(line_col_to_byte("", 0, 0), 0);
    }

    #[test]
    fn line_col_to_byte_single_line() {
        let source = "hello";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 3), 3);
        assert_eq!(line_col_to_byte(source, 0, 5), 5);
    }

    #[test]
    fn line_col_to_byte_multi_line() {
        let source = "abc\ndef\nghi\n";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 2), 2);
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        assert_eq!(line_col_to_byte(source, 1, 3), 7);
        assert_eq!(line_col_to_byte(source, 2, 0), 8);
        assert_eq!(line_col_to_byte(source, 2, 2), 10);
    }

    #[test]
    fn line_col_to_byte_last_line_no_trailing_newline() {
        let source = "abc\ndef";
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        assert_eq!(line_col_to_byte(source, 1, 3), 7);
    }

    #[test]
    fn line_col_to_byte_multi_byte_utf8() {
        // Columns are byte offsets: "é" is two bytes, so "café" is 5 bytes long.
        let source = "café\nbar";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 5), 5);
        assert_eq!(line_col_to_byte(source, 1, 0), 6);
        assert_eq!(line_col_to_byte(source, 1, 2), 8);
    }

    #[test]
    fn line_col_to_byte_beyond_end() {
        let source = "abc";
        assert_eq!(line_col_to_byte(source, 5, 0), source.len());
    }

    #[test]
    fn line_col_to_byte_col_clamped_to_line_length() {
        let source = "ab\ncd";
        assert_eq!(line_col_to_byte(source, 0, 10), 2);
    }

    #[test]
    fn line_col_to_byte_crlf() {
        let source = "abc\r\ndef\r\nghi\r\n";
        assert_eq!(line_col_to_byte(source, 0, 0), 0);
        assert_eq!(line_col_to_byte(source, 0, 10), 3);
        assert_eq!(line_col_to_byte(source, 1, 0), 5);
        assert_eq!(line_col_to_byte(source, 1, 3), 8);
        assert_eq!(line_col_to_byte(source, 2, 0), 10);
    }

    #[test]
    fn line_col_to_byte_lone_cr_is_line_break() {
        let source = "ab\rcd";
        assert_eq!(line_col_to_byte(source, 0, 10), 2);
        assert_eq!(line_col_to_byte(source, 1, 0), 3);
        assert_eq!(line_col_to_byte(source, 1, 2), 5);
    }

    #[test]
    fn line_col_to_byte_line_after_trailing_newline() {
        // The empty line after a trailing "\n" is addressable; anything past it is EOF.
        let source = "abc\n";
        assert_eq!(line_col_to_byte(source, 1, 0), 4);
        assert_eq!(line_col_to_byte(source, 1, 7), 4);
        assert_eq!(line_col_to_byte(source, 2, 0), 4);
    }

    // ---- replace_byte_range -----------------------------------------------

    /// Unwraps an `InvalidRequest` error message, failing the test on anything else.
    fn expect_invalid_request(result: Result<String, AftError>) -> String {
        match result {
            Err(AftError::InvalidRequest { message }) => message,
            other => panic!("expected InvalidRequest error, got {:?}", other),
        }
    }

    #[test]
    fn replace_byte_range_basic() {
        let source = "hello world";
        let result = replace_byte_range(source, 6, 11, "rust").unwrap();
        assert_eq!(result, "hello rust");
    }

    #[test]
    fn replace_byte_range_delete() {
        let source = "hello world";
        let result = replace_byte_range(source, 5, 11, "").unwrap();
        assert_eq!(result, "hello");
    }

    #[test]
    fn replace_byte_range_insert_at_same_position() {
        let source = "helloworld";
        let result = replace_byte_range(source, 5, 5, " ").unwrap();
        assert_eq!(result, "hello world");
    }

    #[test]
    fn replace_byte_range_replace_entire_string() {
        let source = "old content";
        let result = replace_byte_range(source, 0, source.len(), "new content").unwrap();
        assert_eq!(result, "new content");
    }

    #[test]
    fn replace_byte_range_multi_byte_utf8() {
        // "é" occupies bytes 3..5 of "café".
        assert_eq!(replace_byte_range("café", 3, 5, "e").unwrap(), "cafe");
    }

    #[test]
    fn replace_byte_range_rejects_start_after_end() {
        let message = expect_invalid_request(replace_byte_range("hello", 3, 2, "x"));
        assert_eq!(message, "invalid byte range [3..2): start must be <= end");
    }

    #[test]
    fn replace_byte_range_rejects_end_past_source() {
        let message = expect_invalid_request(replace_byte_range("hello", 0, 6, ""));
        assert_eq!(
            message,
            "invalid byte range [0..6): end exceeds source length 5"
        );
    }

    #[test]
    fn replace_byte_range_rejects_non_char_boundaries() {
        let start_msg = expect_invalid_request(replace_byte_range("café", 4, 5, ""));
        assert_eq!(
            start_msg,
            "invalid byte range [4..5): start is not a char boundary"
        );
        let end_msg = expect_invalid_request(replace_byte_range("café", 0, 4, ""));
        assert_eq!(
            end_msg,
            "invalid byte range [0..4): end is not a char boundary"
        );
    }

    // ---- wants_diff / compute_diff_info -----------------------------------

    #[test]
    fn wants_diff_only_for_boolean_true() {
        assert!(wants_diff(&serde_json::json!({ "include_diff": true })));
        assert!(!wants_diff(&serde_json::json!({ "include_diff": false })));
        assert!(!wants_diff(&serde_json::json!({ "include_diff": "true" })));
        assert!(!wants_diff(&serde_json::json!({})));
    }

    #[test]
    fn compute_diff_info_counts_changed_lines() {
        let info = compute_diff_info("a\nb\n", "a\nc\nd\n");
        assert_eq!(info["additions"], serde_json::json!(2));
        assert_eq!(info["deletions"], serde_json::json!(1));
        assert_eq!(info["before"], serde_json::json!("a\nb\n"));
        assert_eq!(info["after"], serde_json::json!("a\nc\nd\n"));
        assert!(info.get("truncated").is_none());
    }

    #[test]
    fn compute_diff_info_truncates_large_inputs() {
        let big = "x".repeat(512 * 1024 + 1);
        let info = compute_diff_info(&big, "");
        assert_eq!(info["truncated"], serde_json::json!(true));
        assert_eq!(info["additions"], serde_json::json!(0));
        assert_eq!(info["deletions"], serde_json::json!(1));
        assert!(info.get("before").is_none());
        assert!(info.get("after").is_none());
    }
}
495}