use std::path::{Path, PathBuf};
use anyhow::{Context, Result, anyhow};
use serde::{Deserialize, Serialize};
use similar::TextDiff;
use tree_sitter::{Language, Node, Parser};
use uuid::Uuid;
use crate::symgraph::symbol::{detect_language, extract_symbols};
/// A proposed edit to one file, computed in memory and not yet written to disk.
///
/// Built by [`replace_symbol`], [`insert_after_symbol`] and [`add_import`]; persisted by
/// [`apply_patch`], which writes `modified` to `file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Patch {
/// Fresh random identifier: `Uuid::new_v4()` rendered with `to_string()`.
pub id: String,
/// Path of the target file, copied from the `file` argument of the constructor.
pub file: PathBuf,
/// Contents of `file` as read when the patch was built.
pub original: String,
/// Proposed new contents of `file`; equal to `original` when `add_import` found the
/// import already present.
pub modified: String,
/// Unified diff from `original` to `modified`, as produced by [`emit_diff`].
pub diff: String,
}
/// Replaces the source text of the first symbol named `name` in `file` with `new_source`.
///
/// The symbol's byte span comes from `extract_symbols`; every byte outside that span is kept
/// verbatim. Nothing is written to disk — the returned [`Patch`] holds the original text, the
/// proposed text and a unified diff, to be persisted later with [`apply_patch`].
///
/// # Errors
/// Fails if the file cannot be read, its extension is not recognised by `detect_language`,
/// no symbol called `name` exists, or the spliced result no longer parses.
pub fn replace_symbol(file: &Path, name: &str, new_source: &str) -> Result<Patch> {
    let original = std::fs::read_to_string(file)
        .with_context(|| format!("failed to read {}", file.display()))?;
    let (lang, _) = detect_language(file)
        .with_context(|| format!("unsupported file extension: {}", file.display()))?;

    // Only the byte span of the first match is needed; ties go to the earliest symbol.
    let (start, end) = extract_symbols(&original, lang.clone(), file)
        .into_iter()
        .find(|sym| sym.name == name)
        .map(|sym| (sym.start_byte, sym.end_byte))
        .ok_or_else(|| anyhow!("symbol '{name}' not found in {}", file.display()))?;

    let modified = [&original[..start], new_source, &original[end..]].concat();
    validate_syntax(&modified, lang).map_err(|e| {
        anyhow!("modified source has syntax errors after replacing '{name}': {e}")
    })?;

    let label = file.file_name().and_then(|s| s.to_str()).unwrap_or("file");
    let diff = emit_diff(&original, &modified, label);
    Ok(Patch {
        id: Uuid::new_v4().to_string(),
        file: file.to_path_buf(),
        original,
        modified,
        diff,
    })
}
/// Inserts `new_source` right after the first symbol named `anchor` in `file`.
///
/// The new text is preceded by a blank line (`"\n\n"`) and spliced in at the anchor's end
/// byte, so whatever followed the anchor (usually its trailing newline) follows the new text.
/// Nothing is written to disk; see [`apply_patch`].
///
/// # Errors
/// Fails if the file cannot be read, its extension is not recognised by `detect_language`,
/// no symbol called `anchor` exists, or the spliced result no longer parses.
pub fn insert_after_symbol(file: &Path, anchor: &str, new_source: &str) -> Result<Patch> {
    let original = std::fs::read_to_string(file)
        .with_context(|| format!("failed to read {}", file.display()))?;
    let (lang, _) = detect_language(file)
        .with_context(|| format!("unsupported file extension: {}", file.display()))?;

    let anchor_end = extract_symbols(&original, lang.clone(), file)
        .into_iter()
        .find(|sym| sym.name == anchor)
        .map(|sym| sym.end_byte)
        .ok_or_else(|| anyhow!("anchor symbol '{anchor}' not found in {}", file.display()))?;

    let (head, tail) = original.split_at(anchor_end);
    let modified = format!("{head}\n\n{new_source}{tail}");
    validate_syntax(&modified, lang).map_err(|e| {
        anyhow!("modified source has syntax errors after inserting after '{anchor}': {e}")
    })?;

    let label = file.file_name().and_then(|s| s.to_str()).unwrap_or("file");
    let diff = emit_diff(&original, &modified, label);
    Ok(Patch {
        id: Uuid::new_v4().to_string(),
        file: file.to_path_buf(),
        original,
        modified,
        diff,
    })
}
pub fn add_import(file: &Path, import_stmt: &str) -> Result<Patch> {
let source = std::fs::read_to_string(file)
.with_context(|| format!("failed to read {}", file.display()))?;
let (lang, lang_tag) = detect_language(file)
.with_context(|| format!("unsupported file extension: {}", file.display()))?;
if source.contains(import_stmt.trim()) {
let diff = emit_diff(
&source,
&source,
file.file_name().and_then(|s| s.to_str()).unwrap_or("file"),
);
return Ok(Patch {
id: Uuid::new_v4().to_string(),
file: file.to_path_buf(),
original: source.clone(),
modified: source,
diff,
});
}
let import_prefix: &[&str] = match lang_tag {
"rust" => &["use "],
"python" => &["import ", "from "],
"javascript" => &["import "],
"go" => &["import "],
_ => &[],
};
let mut insert_at: usize = 0;
let mut byte_pos: usize = 0;
for line in source.split_inclusive('\n') {
let trimmed = line.trim_start();
if import_prefix.iter().any(|p| trimmed.starts_with(p)) {
insert_at = byte_pos + line.len();
}
byte_pos += line.len();
}
let mut to_insert = String::new();
to_insert.push_str(import_stmt.trim_end());
to_insert.push('\n');
let mut modified = String::with_capacity(source.len() + to_insert.len());
modified.push_str(&source[..insert_at]);
modified.push_str(&to_insert);
modified.push_str(&source[insert_at..]);
if let Err(e) = validate_syntax(&modified, lang) {
return Err(anyhow!(
"modified source has syntax errors after adding import '{import_stmt}': {e}"
));
}
let diff = emit_diff(
&source,
&modified,
file.file_name().and_then(|s| s.to_str()).unwrap_or("file"),
);
Ok(Patch {
id: Uuid::new_v4().to_string(),
file: file.to_path_buf(),
original: source,
modified,
diff,
})
}
/// Parses `source` with `lang` and reports whether the resulting tree is error-free.
///
/// On failure, the `Err` string joins (with `"; "`) up to five messages of the form
/// `syntax error at line L, col C (kind)` with 1-based positions. It is the generic
/// `"parse tree contains errors"` if no individual ERROR/MISSING node was located. Failures
/// to configure the parser or to obtain a tree are reported as `Err` too.
pub fn validate_syntax(source: &str, lang: Language) -> Result<(), String> {
    let mut parser = Parser::new();
    parser
        .set_language(&lang)
        .map_err(|err| format!("set_language: {err}"))?;
    let Some(tree) = parser.parse(source, None) else {
        return Err("parser returned no tree".to_string());
    };

    let root = tree.root_node();
    if root.has_error() {
        let mut found: Vec<String> = Vec::new();
        collect_errors(root, source.as_bytes(), &mut found);
        return Err(if found.is_empty() {
            "parse tree contains errors".to_string()
        } else {
            found.join("; ")
        });
    }
    Ok(())
}
/// Appends a message for each ERROR or MISSING node under `node` (inclusive) to `out`.
///
/// Nodes are visited in depth-first pre-order. Collection stops once `out` holds five
/// entries. Positions are reported 1-based. `_bytes` (the source text) is currently unused.
fn collect_errors(node: Node, _bytes: &[u8], out: &mut Vec<String>) {
    const LIMIT: usize = 5;
    // Explicit stack instead of recursion; children are pushed in reverse so they pop in
    // source order, reproducing a pre-order walk.
    let mut pending = vec![node];
    while let Some(current) = pending.pop() {
        if out.len() >= LIMIT {
            break;
        }
        if current.is_error() || current.is_missing() {
            let at = current.start_position();
            out.push(format!(
                "syntax error at line {}, col {} ({})",
                at.row + 1,
                at.column + 1,
                current.kind()
            ));
        }
        let mut cursor = current.walk();
        let children: Vec<_> = current.children(&mut cursor).collect();
        pending.extend(children.into_iter().rev());
    }
}
/// Renders a line-based unified diff (3 lines of context) from `original` to `modified`.
///
/// The `--- a/<filename>` / `+++ b/<filename>` header is always emitted, even when the two
/// texts are identical and there are no hunks.
pub fn emit_diff(original: &str, modified: &str, filename: &str) -> String {
    let hunks: String = TextDiff::from_lines(original, modified)
        .unified_diff()
        .context_radius(3)
        .iter_hunks()
        .map(|hunk| hunk.to_string())
        .collect();
    format!("--- a/{filename}\n+++ b/{filename}\n{hunks}")
}
pub fn apply_patch(patch: &Patch) -> Result<()> {
std::fs::write(&patch.file, &patch.modified)
.with_context(|| format!("failed to write {}", patch.file.display()))?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Writes `body` to the file `dir/name` (creating or truncating it) and returns its path.
    fn write_tmp(dir: &Path, name: &str, body: &str) -> PathBuf {
        let target = dir.join(name);
        std::fs::write(&target, body).unwrap();
        target
    }

    /// The Rust grammar the tests in this module parse with.
    fn rust_language() -> Language {
        tree_sitter_rust::LANGUAGE.into()
    }

    #[test]
    fn validate_syntax_ok() {
        let verdict = validate_syntax("fn main() { let x = 1; }\n", rust_language());
        assert!(verdict.is_ok());
    }

    #[test]
    fn validate_syntax_err() {
        // Unclosed parameter list and a `let` with no initializer expression.
        let verdict = validate_syntax("fn main( { let x = ; }\n", rust_language());
        assert!(verdict.is_err(), "expected syntax error, got {:?}", verdict);
    }

    #[test]
    fn emit_diff_contains_plus_minus() {
        let before = "line one\nline two\nline three\n";
        let after = "line one\nline TWO\nline three\n";
        let rendered = emit_diff(before, after, "test.txt");
        assert!(rendered.contains("-line two"), "diff missing - line: {}", rendered);
        assert!(rendered.contains("+line TWO"), "diff missing + line: {}", rendered);
    }

    #[test]
    fn replace_symbol_round_trips() {
        let dir = tempdir().unwrap();
        let source = "fn foo() -> i32 { 1 }\n\nfn bar() -> i32 { 2 }\n";
        let path = write_tmp(dir.path(), "x.rs", source);

        let patch = replace_symbol(&path, "foo", "fn foo() -> i32 { 42 }").unwrap();

        assert_ne!(patch.original, patch.modified);
        assert!(patch.modified.contains("42"));
        assert!(!patch.diff.is_empty());
        assert!(validate_syntax(&patch.modified, rust_language()).is_ok());
    }

    #[test]
    fn add_import_skips_duplicates() {
        let dir = tempdir().unwrap();
        let path = write_tmp(dir.path(), "x.rs", "use std::io;\n\nfn main() {}\n");

        let patch = add_import(&path, "use std::io;").unwrap();

        assert_eq!(
            patch.original, patch.modified,
            "duplicate import should noop"
        );
    }

    #[test]
    fn add_import_inserts_after_existing() {
        let dir = tempdir().unwrap();
        let path = write_tmp(dir.path(), "x.rs", "use std::io;\n\nfn main() {}\n");

        let patch = add_import(&path, "use std::fs;").unwrap();
        let text = patch.modified.as_str();

        assert!(text.contains("use std::fs;"));
        assert!(text.contains("use std::io;"));
        let io_at = text.find("use std::io;").unwrap();
        let fs_at = text.find("use std::fs;").unwrap();
        assert!(fs_at > io_at, "fs should be inserted after io");
    }
}