use std::ops::Range;
use std::path::Path;
use crate::doc_writer::formats::{doc_format_for, format_doc_comment, InsertionPosition};
use crate::doc_writer::DocCommentResult;
use crate::language::Language;
use crate::parser::Parser;
/// Errors that can occur while writing generated doc comments back into a
/// source file.
#[derive(Debug, thiserror::Error)]
pub enum DocWriterError {
    /// A filesystem operation (locking, reading, or writing) failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// The source file could not be re-parsed before rewriting it.
    #[error("Parser error: {0}")]
    Parser(#[from] crate::parser::ParserError),

    /// The named function could not be found in the file.
    #[error("Function not found in file: {0}")]
    FunctionNotFound(String),
}
/// A doc edit whose target location in the current file contents is known.
#[derive(Debug)]
struct ResolvedEdit {
    /// 0-based, end-exclusive line range of an existing doc block to remove
    /// before inserting, if one was detected.
    remove_range: Option<Range<usize>>,
    /// 0-based line index at which `new_lines` are inserted.
    insert_at: usize,
    /// Formatted doc lines, each already terminated with a line ending.
    new_lines: Vec<String>,
}
/// Computes the 1-based line at which a generated doc comment should be
/// inserted for a definition whose chunk starts at `line_start` (1-based).
///
/// * `InsideBody` formats (docstrings) go on the first line of the body, i.e.
///   the line after the one that ends the signature. Multi-line signatures and
///   leading decorators are skipped. If the end of the signature cannot be
///   located (e.g. a one-line `def f(): return 1`), the line right after
///   `line_start` is used.
/// * `BeforeFunction` formats go above any attribute/decorator lines (`@...`,
///   `#[...]`, `#![...]`, `[...]`) directly preceding the definition. Blank
///   lines are stepped over only once at least one attribute has been seen.
///   This includes blank lines above the topmost attribute, so the result may
///   sit above such a blank line.
///
/// Returns `line_start` unchanged when `file_lines` is empty or `line_start`
/// is 0.
pub fn find_insertion_point(line_start: usize, file_lines: &[&str], language: Language) -> usize {
    let _span = tracing::debug_span!("find_insertion_point", line_start, %language).entered();
    if file_lines.is_empty() || line_start == 0 {
        return line_start;
    }
    let format = doc_format_for(language);
    match format.position {
        InsertionPosition::InsideBody => match find_signature_end(file_lines, line_start - 1) {
            // The body starts on the line after the signature: 0-based
            // `end + 1`, i.e. 1-based `end + 2`.
            Some(end) => end + 2,
            None => line_start + 1,
        },
        InsertionPosition::BeforeFunction => {
            if line_start <= 1 {
                return line_start;
            }
            // 0-based index of the line directly above the definition.
            let mut idx = line_start - 2;
            if idx >= file_lines.len() {
                return line_start;
            }
            let mut seen_decorator = false;
            loop {
                let trimmed = file_lines[idx].trim();
                let is_decorator = trimmed.starts_with('@')
                    || trimmed.starts_with("#[")
                    || trimmed.starts_with("#![")
                    || trimmed.starts_with('[');
                if is_decorator {
                    seen_decorator = true;
                } else if !(trimmed.is_empty() && seen_decorator) {
                    // First line that is neither an attribute nor a blank
                    // inside an attribute run: insert just below it
                    // (0-based idx + 1, i.e. 1-based idx + 2).
                    return idx + 2;
                }
                if idx == 0 {
                    // Attributes (and blanks) run all the way to the top.
                    return 1;
                }
                idx -= 1;
            }
        }
    }
}

/// Locates the line that ends a Python-style signature, scanning forward from
/// the 0-based `start_idx`.
///
/// Brackets are counted so a multi-line parameter list is followed to its
/// closing line. Decorator statements (`@...`, possibly spanning several
/// lines) and blank or comment-only lines are skipped. Returns the 0-based
/// index of the first line whose code, ignoring a trailing `#` comment, ends
/// with `:` once all brackets are closed.
///
/// Returns `None` if the first complete non-decorator statement does not end
/// with `:` (e.g. `def f(): return 1`), or if the end of the file is reached.
/// Quotes are tracked per line only, so a string that spans lines can
/// confuse the scan.
fn find_signature_end(file_lines: &[&str], start_idx: usize) -> Option<usize> {
    let mut depth: i32 = 0;
    // `Some(true)` inside a decorator statement, `Some(false)` inside the
    // definition header, `None` between statements.
    let mut in_decorator: Option<bool> = None;
    for (idx, line) in file_lines.iter().enumerate().skip(start_idx) {
        let mut code_end = line.len();
        let mut quote: Option<char> = None;
        let mut escaped = false;
        for (pos, ch) in line.char_indices() {
            if let Some(q) = quote {
                if escaped {
                    escaped = false;
                } else if ch == '\\' {
                    escaped = true;
                } else if ch == q {
                    quote = None;
                }
                continue;
            }
            match ch {
                '"' | '\'' => quote = Some(ch),
                '#' => {
                    code_end = pos;
                    break;
                }
                '(' | '[' | '{' => depth += 1,
                ')' | ']' | '}' => depth -= 1,
                _ => {}
            }
        }
        // A stray closer must not leave the scan permanently "unbalanced".
        depth = depth.max(0);
        let code = line[..code_end].trim();
        if in_decorator.is_none() {
            if code.is_empty() {
                continue;
            }
            in_decorator = Some(code.starts_with('@'));
        }
        if depth > 0 || code.ends_with('\\') {
            // The statement continues on the next line.
            continue;
        }
        if in_decorator == Some(true) {
            in_decorator = None;
            continue;
        }
        return code.ends_with(':').then_some(idx);
    }
    None
}
/// Finds an existing doc comment attached to a definition, given the 1-based
/// `insertion_line` produced by [`find_insertion_point`].
///
/// Returns the 0-based, end-exclusive range of lines the doc occupies, or
/// `None` if no doc is detected.
///
/// * `InsideBody` (docstrings): the line at `insertion_line` must open a `"""`
///   or `'''` string, optionally with an `r`/`R`/`u`/`U` prefix.
///   - A line that both opens and closes the string is a single-line
///     docstring, including empty ones such as `""""""`.
///   - Otherwise the range extends to the first later line that ends with the
///     same delimiter.
///   - An unterminated docstring yields `None`.
/// * `BeforeFunction` with a line prefix (e.g. `///`): the contiguous run of
///   prefixed lines ending directly above `insertion_line`.
/// * `BeforeFunction` with a block comment: the block whose opening prefix or
///   closing delimiter is on the line directly above `insertion_line`. The
///   upward search for the opening prefix stops at the first line containing
///   a closing delimiter, so it cannot reach back into an earlier, unrelated
///   comment and swallow the code in between.
pub fn detect_existing_doc_range(
    insertion_line: usize,
    file_lines: &[&str],
    language: Language,
) -> Option<Range<usize>> {
    let _span =
        tracing::debug_span!("detect_existing_doc_range", insertion_line, %language).entered();
    let format = doc_format_for(language);
    match format.position {
        InsertionPosition::InsideBody => {
            let idx = insertion_line.checked_sub(1)?;
            let trimmed = file_lines.get(idx)?.trim();
            // Raw/unicode string prefixes are valid on docstrings.
            let unprefixed = trimmed
                .strip_prefix(|c: char| matches!(c, 'r' | 'R' | 'u' | 'U'))
                .unwrap_or(trimmed);
            let delimiter = if unprefixed.starts_with("\"\"\"") {
                "\"\"\""
            } else if unprefixed.starts_with("'''") {
                "'''"
            } else {
                return None;
            };
            // Opening and closing delimiters on one line (6+ chars so a lone
            // opening `"""` is not mistaken for a closed string).
            if unprefixed.len() >= 6 && unprefixed.ends_with(delimiter) {
                return Some(idx..idx + 1);
            }
            file_lines
                .iter()
                .enumerate()
                .skip(idx + 1)
                .find(|(_, line)| line.trim().ends_with(delimiter))
                .map(|(end_idx, _)| idx..end_idx + 1)
        }
        InsertionPosition::BeforeFunction => {
            if insertion_line < 2 || file_lines.is_empty() {
                return None;
            }
            let doc_prefix = if !format.line_prefix.is_empty() {
                format.line_prefix.trim_end()
            } else if !format.prefix.is_empty() {
                format.prefix.trim_end()
            } else {
                return None;
            };
            // 0-based index of the line directly above the insertion point.
            let start_idx = insertion_line - 2;
            let trimmed = file_lines.get(start_idx)?.trim();

            if !format.line_prefix.is_empty() {
                if !trimmed.starts_with(doc_prefix) {
                    return None;
                }
                // Extend upward over the contiguous run of prefixed lines.
                let run_above = file_lines[..start_idx]
                    .iter()
                    .rev()
                    .take_while(|line| line.trim().starts_with(doc_prefix))
                    .count();
                return Some(start_idx - run_above..start_idx + 1);
            }

            // Block comment: the line above either opens the doc itself or
            // closes it.
            if trimmed.starts_with(doc_prefix) {
                return Some(start_idx..start_idx + 1);
            }
            let suffix = format.suffix.trim();
            if suffix.is_empty() || !trimmed.ends_with(suffix) {
                // Without a closing delimiter the block cannot be bounded.
                return None;
            }
            for top in (0..start_idx).rev() {
                let line = file_lines[top].trim();
                if line.contains(suffix) {
                    // An earlier comment closes here: whatever closes at
                    // `start_idx` did not open with the doc prefix.
                    return None;
                }
                if line.starts_with(doc_prefix) {
                    return Some(top..start_idx + 1);
                }
            }
            None
        }
    }
}
pub fn rewrite_file(
path: &Path,
edits: &[DocCommentResult],
parser: &Parser,
) -> Result<usize, DocWriterError> {
let _span = tracing::info_span!("rewrite_file", file = %path.display()).entered();
if edits.is_empty() {
return Ok(0);
}
let lock_path = path.with_extension("cqs-lock");
let lock_file = std::fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(false)
.open(&lock_path)
.map_err(|e| {
DocWriterError::Io(std::io::Error::new(
e.kind(),
format!("{}: {}", lock_path.display(), e),
))
})?;
lock_file.lock()?;
let content = std::fs::read_to_string(path)?;
let file_lines: Vec<&str> = content.lines().collect();
let language = edits[0].language;
if edits.iter().any(|e| e.language != language) {
tracing::warn!(
file = %path.display(),
expected = %language,
"Mixed languages in doc edits for one file — using {}", language
);
}
let chunks = match parser.parse_source(&content, language, path) {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, file = %path.display(), "Failed to parse file for doc rewrite");
return Err(DocWriterError::Parser(e));
}
};
let mut resolved: Vec<ResolvedEdit> = Vec::new();
for edit in edits {
if edit.language != language {
continue;
}
let matching_chunks: Vec<_> = chunks
.iter()
.filter(|c| c.name == edit.function_name)
.collect();
let chunk = if matching_chunks.is_empty() {
tracing::warn!(
function = %edit.function_name,
file = %path.display(),
"Function not found in re-parsed file, skipping"
);
continue;
} else if matching_chunks.len() == 1 {
matching_chunks[0]
} else {
matching_chunks
.iter()
.min_by_key(|c| (c.line_start as isize - edit.line_start as isize).unsigned_abs())
.expect("matching_chunks guaranteed non-empty by else-if guard")
};
let line_start = chunk.line_start as usize;
let insertion_line = find_insertion_point(line_start, &file_lines, language);
let existing_range = detect_existing_doc_range(insertion_line, &file_lines, language);
if let Some(ref range) = existing_range {
let existing_doc: String = file_lines[range.clone()]
.iter()
.map(|l| l.trim())
.collect::<Vec<_>>()
.join("\n");
if existing_doc.len() >= 30 {
tracing::debug!(
function = %edit.function_name,
"Function already has adequate doc, skipping"
);
continue;
}
}
let chunk_line_idx = line_start.saturating_sub(1); let indent = if chunk_line_idx < file_lines.len() {
let line = file_lines[chunk_line_idx];
let stripped = line.trim_start();
&line[..line.len() - stripped.len()]
} else {
""
};
let format = doc_format_for(language);
let effective_indent = if format.position == InsertionPosition::InsideBody {
let body_idx = line_start; if body_idx < file_lines.len() && !file_lines[body_idx].trim().is_empty() {
let body_line = file_lines[body_idx];
let stripped = body_line.trim_start();
body_line[..body_line.len() - stripped.len()].to_string()
} else {
format!("{indent} ")
}
} else {
indent.to_string()
};
let formatted = format_doc_comment(
&edit.generated_doc,
language,
&effective_indent,
&edit.function_name,
);
if formatted.is_empty() {
continue;
}
let new_lines: Vec<String> = formatted.lines().map(|l| format!("{l}\n")).collect();
let insert_at_0 = insertion_line.saturating_sub(1);
tracing::debug!(
function = %edit.function_name,
insert_at = insertion_line,
existing_doc = existing_range.is_some(),
"Resolved doc edit"
);
resolved.push(ResolvedEdit {
remove_range: existing_range,
insert_at: insert_at_0,
new_lines,
});
}
let skipped = edits.len() - resolved.len();
if skipped > 0 {
tracing::info!(
file = %path.display(),
total = edits.len(),
skipped,
resolved = resolved.len(),
"Skipped doc edits (not found, adequate doc, or empty)"
);
}
if resolved.is_empty() {
return Ok(0);
}
resolved.sort_by(|a, b| b.insert_at.cmp(&a.insert_at));
let mut lines: Vec<String> = content.lines().map(|l| format!("{l}\n")).collect();
if content.ends_with('\n') && !lines.is_empty() {
} else if !content.ends_with('\n') && !lines.is_empty() {
if let Some(last) = lines.last_mut() {
if last.ends_with('\n') {
last.pop();
}
}
}
let count = resolved.len();
for edit in &resolved {
if let Some(ref range) = edit.remove_range {
if range.start < lines.len() {
let end = range.end.min(lines.len());
lines.drain(range.start..end);
}
}
let insert_at = if let Some(ref range) = edit.remove_range {
edit.insert_at
.saturating_sub(range.end.saturating_sub(range.start))
.min(lines.len())
} else {
edit.insert_at.min(lines.len())
};
for (i, line) in edit.new_lines.iter().enumerate() {
lines.insert(insert_at + i, line.clone());
}
}
let result_content: String = lines.concat();
atomic_write(path, result_content.as_bytes())?;
tracing::debug!(file = %path.display(), count, "Wrote doc comments");
Ok(count)
}
/// Replaces the contents of `path` with `data`, as atomically as the platform
/// allows.
///
/// Normal path:
/// 1. Write `data` to a uniquely named temporary file in the same directory
///    and flush it to disk.
/// 2. On Unix, give it the original file's permissions, if the original
///    exists.
/// 3. Rename it over `path`.
///
/// Fallback, if the rename fails:
/// 1. Remove the temporary file.
/// 2. Copy the original to a backup file.
/// 3. Overwrite `path` in place.
/// 4. If that write fails, rename the backup back over `path`.
///
/// # Errors
/// Returns one of:
/// * the error from writing the temporary file;
/// * the error from creating the backup copy;
/// * when both the rename and the in-place write fail, the in-place write
///   error.
fn atomic_write(path: &Path, data: &[u8]) -> Result<(), std::io::Error> {
    /// Creates or truncates `target`, writes `data`, and flushes it to disk.
    fn write_synced(target: &Path, data: &[u8]) -> std::io::Result<()> {
        use std::io::Write;
        let mut file = std::fs::File::create(target)?;
        file.write_all(data)?;
        file.sync_all()
    }

    let dir = path.parent().unwrap_or(Path::new("."));
    let suffix = crate::temp_suffix();
    let temp_path = dir.join(format!(".cqs-doc-{}-{}.tmp", std::process::id(), suffix));
    if let Err(e) = write_synced(&temp_path, data) {
        let _ = std::fs::remove_file(&temp_path);
        return Err(e);
    }
    // The temp file was created with default permissions. Carry over the
    // original's mode (e.g. an executable bit) so the rename does not
    // silently drop it. Unix only: on Windows the only flag is read-only,
    // which would just block cleanup of the temp file.
    if cfg!(unix) {
        if let Ok(meta) = std::fs::metadata(path) {
            if let Err(e) = std::fs::set_permissions(&temp_path, meta.permissions()) {
                tracing::warn!(
                    path = %path.display(),
                    error = %e,
                    "Failed to copy file permissions to temp file"
                );
            }
        }
    }
    match std::fs::rename(&temp_path, path) {
        Ok(()) => Ok(()),
        Err(rename_err) => {
            let _ = std::fs::remove_file(&temp_path);
            let backup_path = dir.join(format!(".cqs-doc-{}-{}.bak", std::process::id(), suffix));
            let has_backup = if path.exists() {
                std::fs::copy(path, &backup_path)
                    .map(|_| true)
                    .map_err(|e| {
                        tracing::warn!(
                            path = %path.display(),
                            error = %e,
                            "Cross-device fallback: failed to create backup"
                        );
                        e
                    })?
            } else {
                false
            };
            match std::fs::write(path, data) {
                Ok(()) => {
                    if has_backup {
                        let _ = std::fs::remove_file(&backup_path);
                    }
                    Ok(())
                }
                Err(write_err) => {
                    if has_backup {
                        // Restore the original contents from the backup.
                        let _ = std::fs::rename(&backup_path, path);
                    }
                    tracing::warn!(
                        path = %path.display(),
                        rename_error = %rename_err,
                        write_error = %write_err,
                        "Atomic write failed: both rename and fallback write failed"
                    );
                    Err(write_err)
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use crate::doc_writer::DocCommentResult;
    use crate::language::Language;

    /// Builds a `DocCommentResult` for `function_name` in `file` with a fixed
    /// placeholder content hash.
    fn make_edit(
        file: &Path,
        function_name: &str,
        generated_doc: &str,
        language: Language,
        line_start: usize,
        had_existing_doc: bool,
    ) -> DocCommentResult {
        DocCommentResult {
            file: file.to_path_buf(),
            function_name: function_name.to_string(),
            content_hash: "test_hash".to_string(),
            generated_doc: generated_doc.to_string(),
            language,
            line_start,
            had_existing_doc,
        }
    }

    #[test]
    fn test_insertion_point_plain_function() {
        let lines = vec!["", "fn hello() {", "}", ""];
        let point = find_insertion_point(2, &lines, Language::Rust);
        assert_eq!(point, 2, "Should insert right before function");
    }

    #[test]
    fn test_insertion_point_with_attributes() {
        let lines = vec![
            "use std::fmt;",
            "",
            "#[derive(Debug)]",
            "#[cfg(test)]",
            "fn hello() {",
            "}",
        ];
        let point = find_insertion_point(5, &lines, Language::Rust);
        assert_eq!(point, 2, "Should insert above first attribute");
    }

    #[test]
    fn test_insertion_point_python_inside_body() {
        let lines = vec!["def hello():", " pass"];
        let point = find_insertion_point(1, &lines, Language::Python);
        assert_eq!(point, 2, "Should insert on line after def");
    }

    #[test]
    fn test_insertion_point_with_at_decorator() {
        let lines = vec![
            "import os",
            "",
            "@staticmethod",
            "@decorator",
            "def hello():",
            " pass",
        ];
        let point = find_insertion_point(5, &lines, Language::Python);
        assert_eq!(point, 6, "Python inserts inside body, ignores decorators");
    }

    #[test]
    fn test_insertion_point_first_line_of_file() {
        let lines = vec!["fn hello() {", "}"];
        let point = find_insertion_point(1, &lines, Language::Rust);
        assert_eq!(point, 1, "Should insert at line 1 when function is first");
    }

    #[test]
    fn test_insertion_point_attribute_at_top_of_file() {
        let lines = vec!["#[test]", "fn hello() {", "}"];
        let point = find_insertion_point(2, &lines, Language::Rust);
        assert_eq!(point, 1, "Should insert at line 1 above attribute at top");
    }

    #[test]
    fn test_detect_no_existing_doc() {
        let lines = vec!["use std::fmt;", "", "fn hello() {", "}"];
        let range = detect_existing_doc_range(3, &lines, Language::Rust);
        assert!(range.is_none(), "No doc comment should be detected");
    }

    #[test]
    fn test_detect_rust_doc_comment() {
        let lines = vec!["/// Does a thing.", "/// More detail.", "fn hello() {", "}"];
        let range = detect_existing_doc_range(3, &lines, Language::Rust);
        assert_eq!(range, Some(0..2), "Should detect two-line /// block");
    }

    #[test]
    fn test_detect_single_line_rust_doc() {
        let lines = vec!["/// Short.", "fn hello() {", "}"];
        let range = detect_existing_doc_range(2, &lines, Language::Rust);
        assert_eq!(range, Some(0..1), "Should detect single-line /// doc");
    }

    #[test]
    fn test_detect_python_docstring_single_line() {
        let lines = vec!["def hello():", " \"\"\"Does a thing.\"\"\"", " pass"];
        let range = detect_existing_doc_range(2, &lines, Language::Python);
        assert_eq!(range, Some(1..2), "Should detect single-line docstring");
    }

    #[test]
    fn test_detect_python_docstring_multiline() {
        let lines = vec![
            "def hello():",
            " \"\"\"",
            " Does a thing.",
            " \"\"\"",
            " pass",
        ];
        let range = detect_existing_doc_range(2, &lines, Language::Python);
        assert_eq!(range, Some(1..4), "Should detect multi-line docstring");
    }

    #[test]
    fn test_detect_no_python_docstring() {
        let lines = vec!["def hello():", " pass"];
        let range = detect_existing_doc_range(2, &lines, Language::Python);
        assert!(range.is_none(), "No docstring present");
    }

    #[test]
    fn test_rewrite_rust_undocumented_function() {
        let source = "fn hello() {\n println!(\"hi\");\n}\n";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edit = make_edit(
            tmp.path(),
            "hello",
            "Prints a greeting.",
            Language::Rust,
            1,
            false,
        );
        let count = rewrite_file(tmp.path(), &[edit], &parser).unwrap();
        assert_eq!(count, 1);
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        assert!(
            result.contains("/// Prints a greeting."),
            "Should contain doc comment, got:\n{result}"
        );
        assert!(
            result.find("/// Prints a greeting.").unwrap() < result.find("fn hello()").unwrap(),
            "Doc should appear before function"
        );
    }

    #[test]
    fn test_rewrite_rust_replace_thin_doc() {
        let source = "/// Short\nfn hello() {\n println!(\"hi\");\n}\n";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edit = make_edit(
            tmp.path(),
            "hello",
            "Prints a friendly greeting to stdout.",
            Language::Rust,
            2,
            true,
        );
        let count = rewrite_file(tmp.path(), &[edit], &parser).unwrap();
        assert_eq!(count, 1);
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        assert!(
            !result.contains("/// Short"),
            "Old thin doc should be removed, got:\n{result}"
        );
        assert!(
            result.contains("/// Prints a friendly greeting to stdout."),
            "New doc should be inserted, got:\n{result}"
        );
    }

    #[test]
    fn test_rewrite_rust_with_decorators() {
        let source = "#[derive(Debug)]\n#[cfg(test)]\nfn hello() {\n println!(\"hi\");\n}\n";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edit = make_edit(
            tmp.path(),
            "hello",
            "Prints a greeting.",
            Language::Rust,
            3,
            false,
        );
        let count = rewrite_file(tmp.path(), &[edit], &parser).unwrap();
        assert_eq!(count, 1);
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        let doc_pos = result.find("/// Prints a greeting.").unwrap();
        let attr_pos = result.find("#[derive(Debug)]").unwrap();
        let fn_pos = result.find("fn hello()").unwrap();
        assert!(
            doc_pos < attr_pos,
            "Doc should be above #[derive], got:\n{result}"
        );
        assert!(
            attr_pos < fn_pos,
            "Attributes should be between doc and fn, got:\n{result}"
        );
    }

    #[test]
    fn test_rewrite_python_inside_body() {
        let source = "def hello():\n pass\n";
        let mut tmp = NamedTempFile::with_suffix(".py").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edit = make_edit(
            tmp.path(),
            "hello",
            "Prints a greeting.",
            Language::Python,
            1,
            false,
        );
        let count = rewrite_file(tmp.path(), &[edit], &parser).unwrap();
        assert_eq!(count, 1);
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        let def_pos = result.find("def hello():").unwrap();
        let doc_pos = result.find("\"\"\"").unwrap();
        assert!(
            doc_pos > def_pos,
            "Docstring should be inside body (after def), got:\n{result}"
        );
    }

    #[test]
    fn test_rewrite_multiple_functions_bottom_up() {
        let source = "\
fn alpha() {
println!(\"a\");
}
fn beta() {
println!(\"b\");
}
fn gamma() {
println!(\"c\");
}
";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edits = vec![
            make_edit(
                tmp.path(),
                "alpha",
                "First function.",
                Language::Rust,
                1,
                false,
            ),
            make_edit(
                tmp.path(),
                "gamma",
                "Third function.",
                Language::Rust,
                9,
                false,
            ),
        ];
        let count = rewrite_file(tmp.path(), &edits, &parser).unwrap();
        assert_eq!(count, 2, "Should modify two functions");
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        assert!(
            result.contains("/// First function."),
            "Alpha doc missing:\n{result}"
        );
        assert!(
            result.contains("/// Third function."),
            "Gamma doc missing:\n{result}"
        );
        // The line directly above `fn beta()` must not be a doc comment.
        let beta_pos = result.find("fn beta()").unwrap();
        let line_before_beta = result[..beta_pos].trim_end().lines().last().unwrap_or("");
        assert!(
            !line_before_beta.trim_start().starts_with("///"),
            "Beta should not get a doc comment, got:\n{result}"
        );
        let alpha_doc = result.find("/// First function.").unwrap();
        let alpha_fn = result.find("fn alpha()").unwrap();
        let gamma_doc = result.find("/// Third function.").unwrap();
        let gamma_fn = result.find("fn gamma()").unwrap();
        assert!(alpha_doc < alpha_fn, "Alpha doc should be before alpha fn");
        assert!(alpha_fn < gamma_doc, "Alpha fn should be before gamma doc");
        assert!(gamma_doc < gamma_fn, "Gamma doc should be before gamma fn");
    }

    #[test]
    fn test_rewrite_function_not_found() {
        let source = "fn hello() {\n println!(\"hi\");\n}\n";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edit = make_edit(
            tmp.path(),
            "nonexistent",
            "This function does not exist.",
            Language::Rust,
            1,
            false,
        );
        let count = rewrite_file(tmp.path(), &[edit], &parser).unwrap();
        assert_eq!(count, 0, "Should return 0 when function not found");
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        assert_eq!(result, source, "File should be unchanged");
    }

    #[test]
    fn test_rewrite_disambiguates_same_name_functions() {
        let source = "\
struct Alpha;
impl Alpha {
fn new() -> Self {
Alpha
}
}
struct Beta;
impl Beta {
fn new() -> Self {
Beta
}
}
";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edit = make_edit(
            tmp.path(),
            "new",
            "Creates a new Beta instance.",
            Language::Rust,
            13,
            false,
        );
        let count = rewrite_file(tmp.path(), &[edit], &parser).unwrap();
        assert_eq!(count, 1, "Should document exactly one function");
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        let beta_pos = result.find("impl Beta").unwrap();
        let doc_pos = result.find("Creates a new Beta").unwrap();
        let alpha_pos = result.find("impl Alpha").unwrap();
        assert!(
            doc_pos > alpha_pos,
            "Doc should not be near Alpha, got:\n{result}"
        );
        // The doc must sit inside `impl Beta`, directly above Beta's `fn new`.
        let beta_new_pos = beta_pos + result[beta_pos..].find("fn new").unwrap();
        assert!(
            doc_pos > beta_pos && doc_pos < beta_new_pos,
            "Doc should be between `impl Beta` and Beta's `fn new`, got:\n{result}"
        );
    }

    #[test]
    fn test_rewrite_skips_adequate_doc() {
        let source = "/// This is a long enough doc comment for the function.\nfn hello() {\n println!(\"hi\");\n}\n";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let edit = make_edit(
            tmp.path(),
            "hello",
            "Replacement doc that should not appear.",
            Language::Rust,
            2,
            true,
        );
        let count = rewrite_file(tmp.path(), &[edit], &parser).unwrap();
        assert_eq!(count, 0, "Should skip function with adequate doc");
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        assert!(
            !result.contains("Replacement"),
            "Original doc should be preserved, got:\n{result}"
        );
        assert!(
            result.contains("This is a long enough"),
            "Original doc should remain"
        );
    }

    #[test]
    fn test_rewrite_empty_edits_returns_zero() {
        let source = "fn hello() {\n println!(\"hi\");\n}\n";
        let mut tmp = NamedTempFile::with_suffix(".rs").unwrap();
        write!(tmp, "{source}").unwrap();
        tmp.flush().unwrap();
        let parser = Parser::new().unwrap();
        let count = rewrite_file(tmp.path(), &[], &parser).unwrap();
        assert_eq!(count, 0, "Empty edits should return 0");
        let result = std::fs::read_to_string(tmp.path()).unwrap();
        assert_eq!(result, source, "File should be unchanged with empty edits");
    }
}