use crate::utils::ast_analyzer::AstAnalyzer;
use crate::utils::audit::{AuditFinding, CodeSnippet};
use anyhow::{Context, Result};
use std::fs;
pub struct CodeExtractor {
ast_analyzer: AstAnalyzer,
}
impl CodeExtractor {
pub fn new() -> Self {
Self {
ast_analyzer: AstAnalyzer::new(),
}
}
pub fn extract_snippet_around_line(
file_path: &str,
target_line: usize,
context_lines: usize,
) -> Result<CodeSnippet> {
let content = fs::read_to_string(file_path)
.with_context(|| format!("Failed to read file: {}", file_path))?;
let lines: Vec<&str> = content.lines().collect();
let total_lines = lines.len();
if target_line == 0 || target_line > total_lines {
anyhow::bail!(
"Invalid line number: {} (file has {} lines)",
target_line,
total_lines
);
}
let start_line = if target_line <= context_lines {
1
} else {
target_line - context_lines
};
let end_line = std::cmp::min(target_line + context_lines, total_lines);
let snippet_lines: Vec<&str> = lines
.iter()
.skip(start_line - 1)
.take(end_line - start_line + 1)
.cloned()
.collect();
let snippet_content = snippet_lines.join("\n");
Ok(CodeSnippet {
file_path: file_path.to_string(),
start_line,
end_line,
content: snippet_content,
context: Some(format!("Code context around line {}", target_line)),
})
}
pub fn extract_snippet_by_pattern(
file_path: &str,
pattern: &str,
context_lines: usize,
) -> Result<Vec<CodeSnippet>> {
let content = fs::read_to_string(file_path)
.with_context(|| format!("Failed to read file: {}", file_path))?;
let lines: Vec<&str> = content.lines().collect();
let mut snippets = Vec::new();
for (line_num, line) in lines.iter().enumerate() {
if line.contains(pattern) {
let target_line = line_num + 1;
if let Ok(snippet) =
Self::extract_snippet_around_line(file_path, target_line, context_lines)
{
snippets.push(snippet);
}
}
}
Ok(snippets)
}
pub fn extract_problematic_code(finding: &AuditFinding) -> Result<CodeSnippet> {
if let Some(code_location) = &finding.code_location {
if let Some((file_path, line_str)) = code_location.split_once(':') {
if let Ok(line_num) = line_str.parse::<usize>() {
return Self::extract_snippet_around_line(file_path, line_num, 3);
}
}
let patterns = Self::extract_patterns_from_finding(finding);
if let Ok(snippets) = Self::extract_snippet_by_pattern(code_location, &patterns[0], 3) {
if let Some(snippet) = snippets.first() {
return Ok(snippet.clone());
}
}
Self::extract_snippet_around_line(code_location, 1, 10)
} else {
anyhow::bail!("No code location specified in finding")
}
}
fn extract_patterns_from_finding(finding: &AuditFinding) -> Vec<String> {
let mut patterns = Vec::new();
match finding.category.as_str() {
"Solana Security" => {
patterns.extend(vec![
"is_signer".to_string(),
"invoke".to_string(),
"AccountInfo".to_string(),
"lamports".to_string(),
]);
}
"Authentication & Authorization" => {
patterns.extend(vec![
"authority".to_string(),
"signer".to_string(),
"owner".to_string(),
]);
}
"Trading Security" => {
patterns.extend(vec![
"slippage".to_string(),
"deadline".to_string(),
"swap".to_string(),
"trade".to_string(),
]);
}
_ => {
if finding.title.to_lowercase().contains("token") {
patterns.push("token".to_string());
}
if finding.title.to_lowercase().contains("transfer") {
patterns.push("transfer".to_string());
}
}
}
if patterns.is_empty() {
patterns.push("fn ".to_string()); }
patterns
}
pub fn generate_enhanced_fix(
&self,
problematic_code: &CodeSnippet,
finding: &AuditFinding,
) -> Result<CodeSnippet> {
match self
.ast_analyzer
.generate_ast_based_fix(finding, problematic_code)
{
Ok(enhanced_fix) => {
println!("✨ Generated AST-based fix suggestion");
Ok(enhanced_fix)
}
Err(e) => {
println!(
"⚠️ AST analysis failed, falling back to pattern-based: {}",
e
);
Self::generate_suggested_fix(problematic_code, finding)
}
}
}
pub fn generate_suggested_fix(
problematic_code: &CodeSnippet,
finding: &AuditFinding,
) -> Result<CodeSnippet> {
let fixed_content = Self::apply_basic_fixes(&problematic_code.content, finding);
Ok(CodeSnippet {
file_path: problematic_code.file_path.clone(),
start_line: problematic_code.start_line,
end_line: problematic_code.end_line,
content: fixed_content,
context: Some("Pattern-based fix suggestion".to_string()),
})
}
fn apply_basic_fixes(content: &str, finding: &AuditFinding) -> String {
let mut fixed_content = content.to_string();
match finding.category.as_str() {
"Authentication & Authorization" => {
if !content.contains("is_signer")
&& (content.contains("AccountInfo") || content.contains("ctx.accounts"))
{
fixed_content = fixed_content.replace(
"let user_account = &ctx.accounts.user_account;",
"let user_account = &ctx.accounts.user_account;\n require!(user_account.is_signer, ErrorCode::MissingSignature);"
);
}
}
"Solana Security" => {
if content.contains("user_account") && !content.contains("key()") {
fixed_content = fixed_content.replace(
"user_account.balance += amount;",
"require!(user_account.key() == expected_user_key, ErrorCode::InvalidUser);\n user_account.balance += amount;"
);
}
}
"Trading Security" => {
if content.contains("claim_rewards") && !content.contains("slippage") {
fixed_content = fixed_content.replace(
"let rewards_to_claim = pool_state.accumulated_rewards * user_share;",
"let rewards_to_claim = pool_state.accumulated_rewards * user_share;\n require!(rewards_to_claim >= min_expected_rewards, ErrorCode::SlippageExceeded);"
);
}
}
_ => {
if content.contains("unwrap()") {
fixed_content = fixed_content
.replace("unwrap()", "map_err(|_| ErrorCode::OperationFailed)?");
}
}
}
fixed_content
}
}
impl Default for CodeExtractor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::utils::audit::AuditSeverity;
use std::fs;
use tempfile::NamedTempFile;
#[test]
fn test_extract_snippet_around_line() {
let temp_file = NamedTempFile::new().unwrap();
let content = "line 1\nline 2\nline 3\nline 4\nline 5\n";
fs::write(temp_file.path(), content).unwrap();
let snippet =
CodeExtractor::extract_snippet_around_line(temp_file.path().to_str().unwrap(), 3, 1)
.unwrap();
assert_eq!(snippet.start_line, 2);
assert_eq!(snippet.end_line, 4);
assert!(snippet.content.contains("line 2"));
assert!(snippet.content.contains("line 3"));
assert!(snippet.content.contains("line 4"));
}
#[test]
fn test_extract_snippet_by_pattern() {
let temp_file = NamedTempFile::new().unwrap();
let content = "fn test1() {}\nlet x = 5;\nfn test2() {}\nlet y = 10;\n";
fs::write(temp_file.path(), content).unwrap();
let snippets =
CodeExtractor::extract_snippet_by_pattern(temp_file.path().to_str().unwrap(), "fn ", 1)
.unwrap();
assert_eq!(snippets.len(), 2);
assert!(snippets[0].content.contains("test1"));
assert!(snippets[1].content.contains("test2"));
}
#[test]
fn test_generate_suggested_fix() {
let problematic_code = CodeSnippet {
file_path: "test.rs".to_string(),
start_line: 1,
end_line: 2,
content:
"let user_account = &ctx.accounts.user_account;\nuser_account.balance += amount;"
.to_string(),
context: None,
};
let finding = AuditFinding {
id: "test".to_string(),
title: "Missing signer validation".to_string(),
description: "Test".to_string(),
severity: AuditSeverity::High,
category: "Authentication & Authorization".to_string(),
cwe_id: None,
cvss_score: None,
impact: "Test".to_string(),
recommendation: "Test".to_string(),
code_location: None,
references: vec![],
};
let fixed_code =
CodeExtractor::generate_suggested_fix(&problematic_code, &finding).unwrap();
assert!(fixed_code.content.contains("is_signer"));
}
}