use crate::error::Result;
use crate::traits::{Repair, RepairStrategy, Validator};
use regex::Regex;
use std::sync::OnceLock;
/// Compiled regular expressions shared by the diff validator and repair strategies.
struct DiffRegexCache {
    /// Matches a unified-diff hunk header such as `@@ -1,3 +1,4 @@`; the `,count`
    /// parts are optional and any text after the closing `@@` is ignored.
    hunk_header: Regex,
}

impl DiffRegexCache {
    /// Compiles every pattern held by the cache.
    ///
    /// # Errors
    /// Returns an error if a pattern fails to compile. The patterns are constants, so
    /// this can only happen through a programming error in this module.
    fn new() -> Result<Self> {
        Ok(Self {
            hunk_header: Regex::new(r"^@@\s+-(\d+)(?:,(\d+))?\s+\+(\d+)(?:,(\d+))?\s+@@")?,
        })
    }
}

/// Lazily-initialised, process-wide regex cache.
static DIFF_REGEX_CACHE: OnceLock<DiffRegexCache> = OnceLock::new();

/// Returns the shared regex cache, compiling it on first use.
///
/// # Panics
/// Panics if the built-in patterns fail to compile, which would be a bug in this module.
fn get_diff_regex_cache() -> &'static DiffRegexCache {
    DIFF_REGEX_CACHE
        .get_or_init(|| DiffRegexCache::new().expect("Failed to initialize diff regex cache"))
}
/// Unified-diff repairer.
///
/// Wraps a [`crate::repairer_base::GenericRepairer`] configured with a
/// [`DiffValidator`] and the diff repair strategies defined in this module.
pub struct DiffRepairer {
    inner: crate::repairer_base::GenericRepairer,
}

impl DiffRepairer {
    /// Creates a repairer with every diff strategy registered.
    ///
    /// The strategies are handed to the generic repairer in this order: missing hunk
    /// headers, line prefixes, missing newlines, malformed hunk ranges, missing file
    /// headers, inconsistent spacing.
    pub fn new() -> Self {
        let validator: Box<dyn Validator> = Box::new(DiffValidator);
        let strategies: Vec<Box<dyn RepairStrategy>> = vec![
            Box::new(FixMissingHunkHeadersStrategy),
            Box::new(FixLinePrefixesStrategy),
            Box::new(FixMissingNewlinesStrategy),
            Box::new(FixMalformedHunkRangesStrategy),
            Box::new(FixMissingFileHeadersStrategy),
            Box::new(FixInconsistentSpacingStrategy),
        ];
        Self {
            inner: crate::repairer_base::GenericRepairer::new(validator, strategies),
        }
    }
}

impl Default for DiffRepairer {
    /// Equivalent to [`DiffRepairer::new`].
    fn default() -> Self {
        DiffRepairer::new()
    }
}
impl Repair for DiffRepairer {
    /// Repairs `content` with the configured strategies and guarantees that any
    /// non-empty result ends with a newline.
    ///
    /// # Errors
    /// Propagates any error returned by the inner generic repairer.
    fn repair(&mut self, content: &str) -> Result<String> {
        let mut repaired = self.inner.repair(content)?;
        // `ends_with('\n')` also covers CRLF line endings.
        if !repaired.is_empty() && !repaired.ends_with('\n') {
            repaired.push('\n');
        }
        Ok(repaired)
    }

    /// Delegates to the inner generic repairer.
    fn needs_repair(&self, content: &str) -> bool {
        self.inner.needs_repair(content)
    }

    /// Heuristic score in `[0.0, 1.0]` for how much `content` looks like a unified diff.
    ///
    /// Adds 0.3 if any line starts with `@@`, 0.2 if any line is a `---`/`+++` file
    /// header, 0.3 if any body line exists (`+`, `-` or space prefix, file headers
    /// excluded), and 0.2 if at least one `@@` line is a well-formed hunk header.
    /// Blank content scores 0.0.
    fn confidence(&self, content: &str) -> f64 {
        if content.trim().is_empty() {
            return 0.0;
        }
        let hunk_header = &get_diff_regex_cache().hunk_header;

        let mut has_hunk = false;
        let mut has_valid_hunk = false;
        let mut has_file_header = false;
        let mut has_body_line = false;
        for line in content.lines() {
            if line.starts_with("@@") {
                has_hunk = true;
                has_valid_hunk = has_valid_hunk || hunk_header.is_match(line);
            } else if line.starts_with("---") || line.starts_with("+++") {
                // File headers start with `-`/`+` too; keep them out of the body count.
                has_file_header = true;
            } else if line.starts_with(['+', '-', ' ']) {
                has_body_line = true;
            }
        }

        let mut score: f64 = 0.0;
        if has_hunk {
            score += 0.3;
        }
        if has_file_header {
            score += 0.2;
        }
        if has_body_line {
            score += 0.3;
        }
        if has_valid_hunk {
            score += 0.2;
        }
        score.min(1.0)
    }
}
/// Validator for unified-diff text.
///
/// Content is accepted when all of the following hold:
/// - it is not blank;
/// - it has at least one `---` line and one `+++` line;
/// - it has at least one `@@` line;
/// - every `@@` line matches `@@ -a[,b] +c[,d] @@`, with no doubled spaces in that
///   range part;
/// - every non-blank line after the first `@@` line starts with `+`, `-`, a space, or
///   `\` (the `\ No newline at end of file` marker).
pub struct DiffValidator;

/// Returns `true` when `text` contains two consecutive ASCII spaces.
fn has_repeated_spaces(text: &str) -> bool {
    text.as_bytes()
        .windows(2)
        .any(|pair| pair[0] == b' ' && pair[1] == b' ')
}

/// Returns `true` when `line` is acceptable inside a hunk body.
fn is_hunk_body_line(line: &str) -> bool {
    line.trim().is_empty() || line.starts_with(['+', '-', ' ', '\\'])
}

impl Validator for DiffValidator {
    /// Returns `true` exactly when [`Validator::validate`] reports no errors, so the two
    /// methods can never disagree.
    fn is_valid(&self, content: &str) -> bool {
        self.validate(content).is_empty()
    }

    /// Lists every problem found in `content`.
    ///
    /// Errors appear in this order:
    /// 1. structural problems (no hunk header, missing file headers);
    /// 2. hunk-header problems, in line order;
    /// 3. body-prefix problems, in line order.
    ///
    /// Line numbers are 1-based.
    fn validate(&self, content: &str) -> Vec<String> {
        if content.trim().is_empty() {
            return vec!["Empty diff content".to_string()];
        }
        let hunk_header = &get_diff_regex_cache().hunk_header;
        let lines: Vec<&str> = content.lines().collect();
        let mut errors = Vec::new();

        if !lines.iter().any(|line| line.starts_with("@@")) {
            errors.push("No hunk headers found (expected lines starting with @@)".to_string());
        }
        let has_old_header = lines.iter().any(|line| line.starts_with("---"));
        let has_new_header = lines.iter().any(|line| line.starts_with("+++"));
        if !(has_old_header && has_new_header) {
            errors.push(
                "Missing file headers (expected lines starting with --- and +++)".to_string(),
            );
        }

        // Hunk headers must match the range syntax. Only the matched range part is checked
        // for doubled spaces: text after the closing `@@` is free-form context.
        for (index, line) in lines.iter().enumerate() {
            if !line.starts_with("@@") {
                continue;
            }
            match hunk_header.find(line) {
                None => errors.push(format!(
                    "Invalid hunk header at line {}: {}",
                    index + 1,
                    line
                )),
                Some(range) if has_repeated_spaces(range.as_str()) => errors.push(format!(
                    "Inconsistent spacing in hunk header at line {}: {}",
                    index + 1,
                    line
                )),
                Some(_) => {}
            }
        }

        let mut in_hunk = false;
        for (index, line) in lines.iter().enumerate() {
            if line.starts_with("@@") {
                in_hunk = true;
            } else if in_hunk && !is_hunk_body_line(line) {
                errors.push(format!(
                    "Invalid diff line prefix at line {}: expected +, -, or space",
                    index + 1
                ));
            }
        }
        errors
    }
}
/// Synthesises a hunk header for diff bodies that have none.
///
/// The strategy only acts when no line starts with `@@` but the content contains body
/// lines: lines prefixed with `+`, `-` or a space, excluding `---`/`+++` file headers.
/// In that case it inserts `@@ -<start>,<old> +<start>,<new> @@` directly before the
/// first body line, which places it after any file headers.
///
/// Counts come from the body lines: `-` lines count toward the old side, `+` lines
/// toward the new side, and context lines toward both. The start line is assumed to be
/// 1, or 0 for an empty side, which is how unified diffs express pure additions and
/// deletions.
struct FixMissingHunkHeadersStrategy;

impl RepairStrategy for FixMissingHunkHeadersStrategy {
    fn apply(&self, content: &str) -> Result<String> {
        // Hunk body line: `+`/`-`/space prefix, but not a `---`/`+++` file header.
        fn is_body_line(line: &str) -> bool {
            line.starts_with(['+', '-', ' '])
                && !line.starts_with("---")
                && !line.starts_with("+++")
        }
        // Unified diffs use start line 0 for an empty range (e.g. `-0,0`).
        fn range_start(count: usize) -> usize {
            if count == 0 {
                0
            } else {
                1
            }
        }

        let lines: Vec<&str> = content.lines().collect();
        if lines.iter().any(|line| line.starts_with("@@")) {
            return Ok(lines.join("\n"));
        }
        let Some(first_body) = lines.iter().position(|line| is_body_line(line)) else {
            return Ok(lines.join("\n"));
        };

        let (mut old_count, mut new_count) = (0usize, 0usize);
        for line in lines[first_body..].iter().filter(|line| is_body_line(line)) {
            match line.chars().next() {
                Some('-') => old_count += 1,
                Some('+') => new_count += 1,
                _ => {
                    old_count += 1;
                    new_count += 1;
                }
            }
        }

        let header = format!(
            "@@ -{},{} +{},{} @@",
            range_start(old_count),
            old_count,
            range_start(new_count),
            new_count
        );
        let mut output: Vec<&str> = Vec::with_capacity(lines.len() + 1);
        output.extend_from_slice(&lines[..first_body]);
        output.push(&header);
        output.extend_from_slice(&lines[first_body..]);
        Ok(output.join("\n"))
    }

    fn priority(&self) -> u8 {
        10
    }

    fn name(&self) -> &str {
        "FixMissingHunkHeaders"
    }
}
/// Repairs missing or misplaced line prefixes inside hunk bodies.
///
/// Only lines after an `@@` header, and before the next `---`/`+++` file header, are
/// considered. Of those, only lines that do not already begin with a valid prefix are
/// changed. Valid prefixes are `+`, `-`, a space, or the `\` of a
/// `\ No newline at end of file` marker. Lines that need repair are handled as follows:
/// - whitespace-only lines become empty lines;
/// - lines whose first non-whitespace character is `+` or `-` lose their leading
///   whitespace, so the prefix sits in column 0;
/// - anything else is treated as an unprefixed context line and gets a leading space,
///   with its content (including indentation) preserved.
struct FixLinePrefixesStrategy;

impl RepairStrategy for FixLinePrefixesStrategy {
    fn apply(&self, content: &str) -> Result<String> {
        let mut in_hunk = false;
        let mut result: Vec<String> = Vec::new();
        for line in content.lines() {
            if line.starts_with("@@") {
                in_hunk = true;
                result.push(line.to_string());
            } else if line.starts_with("---") || line.starts_with("+++") {
                in_hunk = false;
                result.push(line.to_string());
            } else if !in_hunk || line.starts_with(['+', '-', ' ', '\\']) {
                // Outside a hunk, or already a valid hunk line: rewriting it could turn
                // a context line such as `     + x` into an addition.
                result.push(line.to_string());
            } else {
                let body = line.trim_start();
                if body.is_empty() {
                    result.push(String::new());
                } else if body.starts_with(['+', '-']) {
                    result.push(body.to_string());
                } else {
                    result.push(format!(" {}", line));
                }
            }
        }
        Ok(result.join("\n"))
    }

    fn priority(&self) -> u8 {
        8
    }

    fn name(&self) -> &str {
        "FixLinePrefixes"
    }
}
/// Ensures non-empty content ends with a line terminator.
///
/// Empty input is returned unchanged so blank content does not turn into a lone
/// newline.
struct FixMissingNewlinesStrategy;

impl RepairStrategy for FixMissingNewlinesStrategy {
    fn apply(&self, content: &str) -> Result<String> {
        let mut result = content.to_string();
        // `ends_with('\n')` also covers CRLF endings.
        if !result.is_empty() && !result.ends_with('\n') {
            result.push('\n');
        }
        Ok(result)
    }

    fn priority(&self) -> u8 {
        5
    }

    fn name(&self) -> &str {
        "FixMissingNewlines"
    }
}
/// Rewrites `@@` lines that fail the hunk-header regex into the canonical
/// `@@ -a,b +c,d @@` form, using the unsigned integers found on the line.
///
/// How the numbers are interpreted depends on how many are found:
/// - two numbers: the old and new start lines, with both counts defaulting to 1;
/// - three numbers: the third is the new start line, and the new count defaults to 1;
/// - four or more: the first four are used in order;
/// - fewer than two: the line becomes `@@ -1,1 +1,1 @@`.
///
/// Well-formed headers and all non-`@@` lines are left untouched.
struct FixMalformedHunkRangesStrategy;

/// Builds a canonical hunk header from the digits found in a malformed one.
fn rebuild_hunk_header(line: &str) -> String {
    // Split on every non-digit, including `-` and `+`. This keeps range markers out of
    // the numbers and avoids negative values that would render as `@@ --1,...`.
    let numbers: Vec<u64> = line
        .split(|c: char| !c.is_ascii_digit())
        .filter_map(|token| token.parse().ok())
        .collect();
    match numbers.as_slice() {
        [old_start, new_start] => format!("@@ -{},1 +{},1 @@", old_start, new_start),
        [old_start, old_count, new_start] => {
            format!("@@ -{},{} +{},1 @@", old_start, old_count, new_start)
        }
        [old_start, old_count, new_start, new_count, ..] => format!(
            "@@ -{},{} +{},{} @@",
            old_start, old_count, new_start, new_count
        ),
        _ => "@@ -1,1 +1,1 @@".to_string(),
    }
}

impl RepairStrategy for FixMalformedHunkRangesStrategy {
    fn apply(&self, content: &str) -> Result<String> {
        let hunk_header = &get_diff_regex_cache().hunk_header;
        let repaired: Vec<String> = content
            .lines()
            .map(|line| {
                if line.starts_with("@@") && !hunk_header.is_match(line) {
                    rebuild_hunk_header(line)
                } else {
                    line.to_string()
                }
            })
            .collect();
        Ok(repaired.join("\n"))
    }

    fn priority(&self) -> u8 {
        7
    }

    fn name(&self) -> &str {
        "FixMalformedHunkRanges"
    }
}
/// Adds placeholder `--- a/file` / `+++ b/file` headers when the content has none.
///
/// The headers are inserted immediately before the first `@@` line, or at the very top
/// when there is no `@@` line. Content that already contains any `---` or `+++` line is
/// only re-joined with `\n`. Input with no lines is returned verbatim.
struct FixMissingFileHeadersStrategy;

impl RepairStrategy for FixMissingFileHeadersStrategy {
    fn apply(&self, content: &str) -> Result<String> {
        let lines: Vec<&str> = content.lines().collect();
        if lines.is_empty() {
            return Ok(content.to_string());
        }

        let already_has_headers = lines
            .iter()
            .any(|line| line.starts_with("---") || line.starts_with("+++"));
        if already_has_headers {
            return Ok(lines.join("\n"));
        }

        let insert_at = lines
            .iter()
            .position(|line| line.starts_with("@@"))
            .unwrap_or(0);
        let mut output: Vec<&str> = Vec::with_capacity(lines.len() + 2);
        output.extend_from_slice(&lines[..insert_at]);
        output.push("--- a/file");
        output.push("+++ b/file");
        output.extend_from_slice(&lines[insert_at..]);
        Ok(output.join("\n"))
    }

    fn priority(&self) -> u8 {
        6
    }

    fn name(&self) -> &str {
        "FixMissingFileHeaders"
    }
}
/// Normalises whitespace on header lines.
///
/// For `@@` lines, runs of consecutive spaces are collapsed to a single space and the
/// line is trimmed. `---`/`+++` lines are trimmed. All other lines are left untouched.
struct FixInconsistentSpacingStrategy;

/// Collapses every run of consecutive ASCII spaces in `text` into a single space.
///
/// This is a single pass. A repeated `replace` inside a `while contains` loop is
/// quadratic, and never terminates if the search and replacement strings are equal.
fn collapse_spaces(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut previous_was_space = false;
    for ch in text.chars() {
        let is_space = ch == ' ';
        if !(is_space && previous_was_space) {
            out.push(ch);
        }
        previous_was_space = is_space;
    }
    out
}

impl RepairStrategy for FixInconsistentSpacingStrategy {
    fn apply(&self, content: &str) -> Result<String> {
        let normalized: Vec<String> = content
            .lines()
            .map(|line| {
                if line.starts_with("@@") {
                    collapse_spaces(line.trim())
                } else if line.starts_with("---") || line.starts_with("+++") {
                    line.trim().to_string()
                } else {
                    line.to_string()
                }
            })
            .collect();
        Ok(normalized.join("\n"))
    }

    fn priority(&self) -> u8 {
        4
    }

    fn name(&self) -> &str {
        "FixInconsistentSpacing"
    }
}
// Unit tests for the diff validator, the full repairer, and individual strategies.
#[cfg(test)]
mod tests {
    use super::*;

    // A diff with file headers, a hunk header and body lines is expected to validate;
    // the same text without the `@@` header is expected to be rejected.
    // NOTE(review): unified-diff context lines must begin with a space. Confirm that
    // `line1` / `line3` below carry that leading space; otherwise the validator's
    // in-hunk prefix check rejects them and the first assertion fails.
    #[test]
    fn test_diff_validator() {
        let validator = DiffValidator;
        let valid_diff = r#"--- a/file.txt
+++ b/file.txt
@@ -1,3 +1,3 @@
line1
-line2
+line2_modified
line3
"#;
        assert!(validator.is_valid(valid_diff));
        // Same body with no `@@` hunk header at all.
        let invalid_diff = r#"--- a/file.txt
+++ b/file.txt
line1
-line2
+line2_modified
"#;
        assert!(!validator.is_valid(invalid_diff));
    }

    // Running the full repairer on an already well-formed diff keeps its hunk header.
    #[test]
    fn test_diff_repairer_basic() {
        let mut repairer = DiffRepairer::new();
        let valid = r#"--- a/file.txt
+++ b/file.txt
@@ -1,1 +1,1 @@
-old
+new
"#;
        let result = repairer.repair(valid).unwrap();
        assert!(result.contains("@@"));
    }

    // File headers followed by body lines but no `@@` line should gain a hunk header.
    #[test]
    fn test_fix_missing_hunk_headers() {
        let strategy = FixMissingHunkHeadersStrategy;
        let content = r#"--- a/file.txt
+++ b/file.txt
-old
+new
"#;
        let result = strategy.apply(content).unwrap();
        assert!(result.contains("@@"));
    }

    // Unprefixed lines inside a hunk should be given a prefix.
    // NOTE(review): `result.contains(" ")` is already satisfied by the hunk header
    // `@@ -1,1 +1,1 @@`, so this assertion does not verify the prefix repair itself.
    #[test]
    fn test_fix_line_prefixes() {
        let strategy = FixLinePrefixesStrategy;
        let content = r#"@@ -1,1 +1,1 @@
old
new
"#;
        let result = strategy.apply(content).unwrap();
        assert!(result.contains(" -") || result.contains(" +") || result.contains(" "));
    }

    // A hunk with no `---`/`+++` lines should gain placeholder file headers.
    #[test]
    fn test_fix_missing_file_headers() {
        let strategy = FixMissingFileHeadersStrategy;
        let content = r#"@@ -1,1 +1,1 @@
-old
+new
"#;
        let result = strategy.apply(content).unwrap();
        assert!(result.contains("---") || result.contains("+++"));
    }
}