use crate::ai::{AiClient, AiError, AiResult, FixPromptBuilder, Message, PromptTemplate};
use crate::models::{Finding, Severity};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
/// Category of change a fix proposal makes.
///
/// Serialized with snake_case variant names. `determine_fix_type` picks the
/// variant from a finding's severity, title, and description.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FixType {
/// Fallback used when no more specific rule in `determine_fix_type` matches.
Refactor,
/// Chosen when the title mentions "complex" or the description mentions "cyclomatic".
Simplify,
/// Chosen when the title mentions "long" or "too many".
Extract,
/// Not selected by `determine_fix_type` in this file.
Rename,
/// Chosen when the title mentions "unused" or "dead code".
Remove,
/// Chosen for critical-severity findings or titles mentioning "security".
Security,
/// Chosen when the title mentions "type" and the description mentions "hint".
TypeHint,
/// Chosen when the title mentions "docstring" or "documentation".
Documentation,
}
/// How much trust to place in a generated fix.
///
/// Serialized with lowercase variant names. `calculate_confidence` maps a
/// heuristic score onto these levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FixConfidence {
/// Heuristic score of at least 0.9.
High,
/// Heuristic score of at least 0.7 (and below 0.9).
Medium,
/// Any lower score; also assigned by `generate_fix` when the proposed
/// original code cannot be located in the repository files.
Low,
}
/// A single textual replacement inside one file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChange {
/// File to modify; joined onto the repository path before it is read or written.
pub file_path: PathBuf,
/// Text expected to be present in the file; this is what gets replaced.
pub original_code: String,
/// Replacement text for `original_code`.
pub fixed_code: String,
/// First line of the changed region as reported in the model response.
/// NOTE(review): presumably 1-based — confirm against the prompt contract.
pub start_line: u32,
/// Last line of the changed region as reported in the model response.
pub end_line: u32,
/// Free-form explanation of this change; empty when the response omits it.
pub description: String,
}
/// Supporting material the model returned alongside a fix.
///
/// Each list is filled from the same-named array under the response's
/// `evidence` object and is empty when that array is absent.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct Evidence {
/// Strings from `evidence.similar_patterns`.
pub similar_patterns: Vec<String>,
/// Strings from `evidence.documentation_refs`.
pub documentation_refs: Vec<String>,
/// Strings from `evidence.best_practices`; a non-empty list raises the confidence score.
pub best_practices: Vec<String>,
}
/// A proposed fix for one finding, as parsed from a model response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FixProposal {
/// Short identifier: the first 12 hex characters of an MD5 digest over the
/// finding id, its start line, and the creation timestamp.
pub id: String,
/// Id of the finding this proposal addresses.
pub finding_id: String,
/// Category of the fix, chosen before the model is prompted.
pub fix_type: FixType,
/// Heuristic confidence level for the proposal.
pub confidence: FixConfidence,
/// Title from the response; "Auto-generated fix" when the response has none.
pub title: String,
/// Description from the response; empty when absent.
pub description: String,
/// Rationale from the response; empty when absent.
pub rationale: String,
/// Replacements to perform; response entries missing a required field are skipped.
pub changes: Vec<CodeChange>,
/// Supporting evidence returned with the fix.
pub evidence: Evidence,
/// Whether every change passed the heuristic syntax check; `false` until checked.
pub syntax_valid: bool,
}
impl FixProposal {
pub fn diff(&self, repo_path: &Path) -> String {
let mut diff = String::new();
for change in &self.changes {
diff.push_str(&format!("--- a/{}\n", change.file_path.display()));
diff.push_str(&format!("+++ b/{}\n", change.file_path.display()));
diff.push_str(&format!(
"@@ -{},{} +{},{} @@\n",
change.start_line,
change.end_line - change.start_line + 1,
change.start_line,
change.fixed_code.lines().count()
));
for line in change.original_code.lines() {
diff.push_str(&format!("-{}\n", line));
}
for line in change.fixed_code.lines() {
diff.push_str(&format!("+{}\n", line));
}
diff.push('\n');
}
diff
}
pub fn apply(&self, repo_path: &Path) -> AiResult<()> {
for change in &self.changes {
let file_path = repo_path.join(&change.file_path);
let content = fs::read_to_string(&file_path)?;
let new_content = content.replace(&change.original_code, &change.fixed_code);
if new_content == content {
return Err(AiError::ParseError(format!(
"Original code not found in {}",
change.file_path.display()
)));
}
fs::write(&file_path, new_content)?;
}
Ok(())
}
}
/// Produces `FixProposal`s for findings by prompting an AI model.
pub struct FixGenerator {
/// Client used to send prompts and receive the model's responses.
client: AiClient,
}
impl FixGenerator {
    /// Creates a generator that sends its prompts through `client`.
    pub fn new(client: AiClient) -> Self {
        Self { client }
    }

    /// Generates a fix proposal for `finding` with a single model request.
    ///
    /// The returned proposal has `syntax_valid` set from the heuristic syntax
    /// check, and its confidence is forced to [`FixConfidence::Low`] when the
    /// proposed `original_code` cannot be located in the files under
    /// `repo_path`.
    ///
    /// File access uses blocking `std::fs` calls.
    ///
    /// # Errors
    ///
    /// Fails when the finding lists no affected files, the first affected
    /// file cannot be read, the model request fails, or the response is not
    /// valid JSON.
    pub async fn generate_fix(
        &self,
        finding: &Finding,
        repo_path: &Path,
    ) -> AiResult<FixProposal> {
        let mut fix = self.attempt_fix(finding, repo_path, &[]).await?;
        if !self.validate_original_code(&fix, repo_path) {
            fix.confidence = FixConfidence::Low;
        }
        Ok(fix)
    }

    /// Generates a fix proposal, re-prompting with validation errors until a
    /// proposal passes validation or the retry budget is used up.
    ///
    /// At most `max_retries + 1` model requests are made. A proposal that
    /// passes both the syntax heuristic and the original-code check is
    /// returned immediately. When the budget is exhausted, the last proposal
    /// is returned as a best effort instead of issuing yet another request;
    /// its confidence is lowered to [`FixConfidence::Low`] if its original
    /// code could not be located, and `syntax_valid` reflects the syntax
    /// check.
    ///
    /// # Errors
    ///
    /// Same as [`FixGenerator::generate_fix`]; the first such error aborts
    /// the retry loop.
    pub async fn generate_fix_with_retry(
        &self,
        finding: &Finding,
        repo_path: &Path,
        max_retries: u32,
    ) -> AiResult<FixProposal> {
        let mut previous_errors: Vec<String> = Vec::new();
        let mut attempt: u32 = 0;
        loop {
            let mut fix = self
                .attempt_fix(finding, repo_path, &previous_errors)
                .await?;
            let original_found = self.validate_original_code(&fix, repo_path);

            let mut errors = Vec::new();
            if !fix.syntax_valid {
                errors.push("SyntaxError: generated code has syntax errors".to_string());
            }
            if !original_found {
                errors.push("MatchError: Original code not found in file".to_string());
            }
            if errors.is_empty() {
                return Ok(fix);
            }

            if attempt >= max_retries {
                // Budget exhausted: hand back the last attempt, flagged the
                // same way `generate_fix` flags an unmatched proposal.
                if !original_found {
                    fix.confidence = FixConfidence::Low;
                }
                return Ok(fix);
            }
            attempt += 1;
            previous_errors = errors;
        }
    }

    /// Runs one prompt → response → parse round trip for `finding`.
    ///
    /// The language is inferred from the first affected file's extension and
    /// falls back to `"python"`. `previous_errors`, when non-empty, are passed
    /// to the prompt builder so the model can correct an earlier attempt. The
    /// returned proposal already has `syntax_valid` filled in.
    async fn attempt_fix(
        &self,
        finding: &Finding,
        repo_path: &Path,
        previous_errors: &[String],
    ) -> AiResult<FixProposal> {
        let language = finding
            .affected_files
            .first()
            .and_then(|p| p.extension())
            .and_then(|e| e.to_str())
            .map(extension_to_language)
            .unwrap_or("python");
        let fix_type = determine_fix_type(finding);
        let code_section = self.read_code_section(finding, repo_path)?;

        let mut builder = FixPromptBuilder::new(finding.clone(), fix_type, language)
            .code_section(&code_section);
        if !previous_errors.is_empty() {
            builder = builder.previous_errors(previous_errors.to_vec());
        }
        let prompt = builder.build();
        let system_prompt = PromptTemplate::system_prompt(language);
        let response = self
            .client
            .generate(vec![Message::user(prompt)], Some(system_prompt))
            .await?;

        let mut fix = self.parse_response(&response, finding, fix_type)?;
        fix.syntax_valid = self.validate_syntax(&fix, language);
        Ok(fix)
    }

    /// Reads the lines surrounding the finding from its first affected file.
    ///
    /// The window starts at 0-based line index `line_start - 10` and ends
    /// (exclusively) at index `line_end + 20`, both clamped to the file; a
    /// missing `line_end` falls back to `line_start`, and a missing
    /// `line_start` to 1. The selected lines are joined with `\n`.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::ParseError`] when the finding has no affected
    /// files, and propagates the I/O error when the file cannot be read.
    fn read_code_section(&self, finding: &Finding, repo_path: &Path) -> AiResult<String> {
        let file_path = finding
            .affected_files
            .first()
            .ok_or_else(|| AiError::ParseError("No affected files".to_string()))?;
        let content = fs::read_to_string(repo_path.join(file_path))?;
        let lines: Vec<&str> = content.lines().collect();

        let window_start = finding.line_start.unwrap_or(1).saturating_sub(10) as usize;
        let window_end = finding
            .line_end
            .or(finding.line_start)
            .unwrap_or(1)
            .saturating_add(20) as usize;

        let end = window_end.min(lines.len());
        // Clamp against `end` as well: a finding whose `line_end` lies far
        // before its `line_start` would otherwise produce an inverted range
        // and panic on the slice below.
        let start = window_start.min(end);
        Ok(lines[start..end].join("\n"))
    }

    /// Parses the model's response into a [`FixProposal`].
    ///
    /// The JSON payload is taken from the first ```` ```json ```` fenced
    /// block; failing that, from the text between the first `{` and the last
    /// `}`; failing that, the whole response is parsed. Entries of `changes`
    /// that lack a required field (or whose line numbers do not fit in `u32`)
    /// are skipped. `syntax_valid` is left `false` for the caller to fill in.
    ///
    /// # Errors
    ///
    /// Returns [`AiError::ParseError`] when the selected text is not valid JSON.
    fn parse_response(
        &self,
        response: &str,
        finding: &Finding,
        fix_type: FixType,
    ) -> AiResult<FixProposal> {
        // `(?s)` lets `.` match newlines; without it a pretty-printed,
        // multi-line JSON object inside the fence can never match.
        let fenced_json = Regex::new(r"(?s)```json\s*(\{.*?\})\s*```")
            .expect("fenced-JSON pattern is a valid regex");
        let json_str = fenced_json
            .captures(response)
            .and_then(|caps| caps.get(1))
            .map(|m| m.as_str())
            .or_else(|| {
                // No fence: tolerate prose around a bare JSON object.
                let open = response.find('{')?;
                let close = response.rfind('}')?;
                if open < close {
                    Some(&response[open..=close])
                } else {
                    None
                }
            })
            .unwrap_or(response);

        let data: serde_json::Value = serde_json::from_str(json_str).map_err(|e| {
            AiError::ParseError(format!("Failed to parse JSON response: {}", e))
        })?;

        let changes: Vec<CodeChange> = data
            .get("changes")
            .and_then(|c| c.as_array())
            .map(|entries| {
                entries
                    .iter()
                    .filter_map(|change| {
                        // Reject rather than silently truncate out-of-range values.
                        let line_number = |key: &str| -> Option<u32> {
                            let raw = change.get(key)?.as_u64()?;
                            if raw > u64::from(u32::MAX) {
                                None
                            } else {
                                Some(raw as u32)
                            }
                        };
                        Some(CodeChange {
                            file_path: PathBuf::from(change.get("file_path")?.as_str()?),
                            original_code: change.get("original_code")?.as_str()?.to_string(),
                            fixed_code: change.get("fixed_code")?.as_str()?.to_string(),
                            start_line: line_number("start_line")?,
                            end_line: line_number("end_line")?,
                            description: change
                                .get("description")
                                .and_then(|d| d.as_str())
                                .unwrap_or("")
                                .to_string(),
                        })
                    })
                    .collect()
            })
            .unwrap_or_default();

        let evidence = data
            .get("evidence")
            .map(|e| Evidence {
                similar_patterns: extract_string_array(e, "similar_patterns"),
                documentation_refs: extract_string_array(e, "documentation_refs"),
                best_practices: extract_string_array(e, "best_practices"),
            })
            .unwrap_or_default();

        let confidence = calculate_confidence(&data, finding, &changes);

        // First 12 hex characters of an MD5 digest (32 hex characters long).
        let digest = format!(
            "{:x}",
            md5::compute(format!(
                "{}:{}:{}",
                finding.id,
                finding.line_start.unwrap_or(0),
                chrono::Utc::now().timestamp()
            ))
        );
        let fix_id = digest[..12].to_string();

        let text_field = |key: &str, fallback: &str| -> String {
            data.get(key)
                .and_then(|value| value.as_str())
                .unwrap_or(fallback)
                .to_string()
        };

        Ok(FixProposal {
            id: fix_id,
            finding_id: finding.id.clone(),
            fix_type,
            confidence,
            title: text_field("title", "Auto-generated fix"),
            description: text_field("description", ""),
            rationale: text_field("rationale", ""),
            changes,
            evidence,
            syntax_valid: false,
        })
    }

    /// Cheap, language-specific sanity checks on every change's fixed code.
    ///
    /// These are heuristics, not a parser: Python code must contain a `:` if
    /// it contains `def `, and have balanced `()` and `[]`; JavaScript,
    /// TypeScript, Rust, Go, and Java code must have balanced `{}`. Other
    /// languages always pass. Delimiters are counted textually, including any
    /// inside strings or comments.
    fn validate_syntax(&self, fix: &FixProposal, language: &str) -> bool {
        fix.changes.iter().all(|change| {
            let code = change.fixed_code.as_str();
            let balanced = |open: char, close: char| {
                code.matches(open).count() == code.matches(close).count()
            };
            match language {
                "python" => {
                    let def_without_colon = code.contains("def ") && !code.contains(':');
                    !def_without_colon && balanced('(', ')') && balanced('[', ']')
                }
                "javascript" | "typescript" | "rust" | "go" | "java" => balanced('{', '}'),
                _ => true,
            }
        })
    }

    /// Checks that every change's `original_code` occurs in its target file.
    ///
    /// An exact substring match is tried first; otherwise both texts are
    /// compared after trimming each line and dropping blank lines. Returns
    /// `false` as soon as a file cannot be read or a change does not match.
    /// Note that `FixProposal::apply` requires an exact match, so a proposal
    /// accepted only by the whitespace-insensitive comparison can still fail
    /// to apply.
    fn validate_original_code(&self, fix: &FixProposal, repo_path: &Path) -> bool {
        fn normalize(text: &str) -> String {
            text.lines()
                .map(str::trim)
                .filter(|line| !line.is_empty())
                .collect::<Vec<_>>()
                .join("\n")
        }

        fix.changes.iter().all(|change| {
            match fs::read_to_string(repo_path.join(&change.file_path)) {
                Ok(content) => {
                    content.contains(&change.original_code)
                        || normalize(&content).contains(&normalize(&change.original_code))
                }
                Err(_) => false,
            }
        })
    }
}
/// Collects the string elements of the JSON array stored under `key`.
///
/// Non-string elements are skipped. A missing key, or a value that is not an
/// array, yields an empty vector.
fn extract_string_array(value: &serde_json::Value, key: &str) -> Vec<String> {
    match value.get(key).and_then(|entry| entry.as_array()) {
        Some(items) => items
            .iter()
            .filter_map(|item| item.as_str())
            .map(|text| text.to_string())
            .collect(),
        None => Vec::new(),
    }
}
/// Picks the fix category for a finding from its severity and wording.
///
/// Title and description are compared case-insensitively. Rules are checked
/// in order and the first match wins; a finding matching none of them is
/// classified as [`FixType::Refactor`].
fn determine_fix_type(finding: &Finding) -> FixType {
    let title = finding.title.to_lowercase();
    let description = finding.description.to_lowercase();
    let title_has = |needle: &str| title.contains(needle);
    let description_has = |needle: &str| description.contains(needle);

    if finding.severity == Severity::Critical || title_has("security") {
        FixType::Security
    } else if title_has("complex") || description_has("cyclomatic") {
        FixType::Simplify
    } else if title_has("unused") || title_has("dead code") {
        FixType::Remove
    } else if title_has("docstring") || title_has("documentation") {
        FixType::Documentation
    } else if title_has("type") && description_has("hint") {
        FixType::TypeHint
    } else if title_has("long") || title_has("too many") {
        FixType::Extract
    } else {
        FixType::Refactor
    }
}
/// Derives a confidence level for a parsed fix from simple heuristics.
///
/// The score starts at 0.5 and is adjusted as follows:
/// * +0.1 when the fix consists of exactly one change;
/// * +0.1 when the response's `rationale` is longer than 100 bytes;
/// * -0.2 when the finding is critical;
/// * +0.1 when `evidence.best_practices` is a non-empty array.
///
/// Scores of at least 0.9 map to `High`, at least 0.7 to `Medium`, and
/// anything lower to `Low`.
///
/// NOTE(review): with these weights the score tops out at 0.8, so the `High`
/// branch cannot currently be reached — confirm whether that is intended.
fn calculate_confidence(
    data: &serde_json::Value,
    finding: &Finding,
    changes: &[CodeChange],
) -> FixConfidence {
    let single_change = changes.len() == 1;
    let detailed_rationale = data
        .get("rationale")
        .and_then(|r| r.as_str())
        .map(|text| text.len() > 100)
        .unwrap_or(false);
    let is_critical = finding.severity == Severity::Critical;
    let cites_best_practices = data
        .get("evidence")
        .and_then(|evidence| evidence.get("best_practices"))
        .and_then(|b| b.as_array())
        .map(|items| !items.is_empty())
        .unwrap_or(false);

    // Adjustments are applied in a fixed order so the floating-point result
    // is the same for the same inputs.
    let mut score = 0.5;
    if single_change {
        score += 0.1;
    }
    if detailed_rationale {
        score += 0.1;
    }
    if is_critical {
        score -= 0.2;
    }
    if cites_best_practices {
        score += 0.1;
    }

    match score {
        s if s >= 0.9 => FixConfidence::High,
        s if s >= 0.7 => FixConfidence::Medium,
        _ => FixConfidence::Low,
    }
}
/// Maps a file extension (without the leading dot) to a language name.
///
/// Matching ignores ASCII case, so `"PY"` and `"py"` are treated alike.
/// JavaScript and TypeScript module/JSX variants (`jsx`, `mjs`, `cjs`, `tsx`,
/// `mts`, `cts`) map to their base language. Unrecognized extensions fall
/// back to `"python"`.
fn extension_to_language(ext: &str) -> &'static str {
    match ext.to_ascii_lowercase().as_str() {
        "py" => "python",
        "js" | "jsx" | "mjs" | "cjs" => "javascript",
        "ts" | "tsx" | "mts" | "cts" => "typescript",
        "rs" => "rust",
        "go" => "go",
        "java" => "java",
        _ => "python",
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal finding with the given title, description, and severity.
    fn finding_with(title: &str, description: &str, severity: Severity) -> Finding {
        Finding {
            id: "test".to_string(),
            detector: "test".to_string(),
            severity,
            title: title.to_string(),
            description: description.to_string(),
            affected_files: vec![],
            line_start: None,
            line_end: None,
            suggested_fix: None,
            estimated_effort: None,
            category: None,
            cwe_id: None,
            why_it_matters: None,
            ..Default::default()
        }
    }

    #[test]
    fn test_determine_fix_type() {
        let mut finding = finding_with(
            "High cyclomatic complexity",
            "Function has high complexity",
            Severity::Medium,
        );
        assert_eq!(determine_fix_type(&finding), FixType::Simplify);
        finding.title = "Unused variable".to_string();
        assert_eq!(determine_fix_type(&finding), FixType::Remove);
        finding.title = "Missing docstring".to_string();
        assert_eq!(determine_fix_type(&finding), FixType::Documentation);
        finding.severity = Severity::Critical;
        finding.title = "SQL injection vulnerability".to_string();
        assert_eq!(determine_fix_type(&finding), FixType::Security);
    }

    #[test]
    fn test_determine_fix_type_remaining_categories() {
        let extract = finding_with("Function is too long", "Body spans 300 lines", Severity::Medium);
        assert_eq!(determine_fix_type(&extract), FixType::Extract);

        let type_hint = finding_with("Missing type annotation", "Add a type hint", Severity::Medium);
        assert_eq!(determine_fix_type(&type_hint), FixType::TypeHint);

        let fallback = finding_with("Feature envy", "Method uses another class", Severity::Medium);
        assert_eq!(determine_fix_type(&fallback), FixType::Refactor);
    }

    #[test]
    fn test_determine_fix_type_rule_precedence() {
        // Earlier rules win: "complex" is checked before "unused".
        let finding = finding_with("Unused complex helper", "Never called", Severity::Medium);
        assert_eq!(determine_fix_type(&finding), FixType::Simplify);
    }

    #[test]
    fn test_extension_to_language() {
        assert_eq!(extension_to_language("py"), "python");
        assert_eq!(extension_to_language("js"), "javascript");
        assert_eq!(extension_to_language("ts"), "typescript");
        assert_eq!(extension_to_language("tsx"), "typescript");
        assert_eq!(extension_to_language("rs"), "rust");
        assert_eq!(extension_to_language("go"), "go");
        assert_eq!(extension_to_language("java"), "java");
    }

    #[test]
    fn test_extension_to_language_falls_back_to_python() {
        assert_eq!(extension_to_language("txt"), "python");
        assert_eq!(extension_to_language(""), "python");
    }
}