use crate::utils::audit::{AuditFinding, AuditReport, AuditSeverity, AuditSummary, SystemInfo};
use anyhow::{Context, Result};
use rand::Rng;
use serde_json::json;
use std::collections::HashMap;
use std::path::Path;
use tera::{Context as TeraContext, Tera};
/// Renders [`AuditReport`]s to files through Tera templates.
///
/// The built-in templates are registered by [`TemplateReportGenerator::new`];
/// further templates can be added with [`TemplateReportGenerator::add_template`].
pub struct TemplateReportGenerator {
/// Registry holding every template this generator can render by name.
tera: Tera,
}
impl TemplateReportGenerator {
    /// Creates a generator with the built-in templates registered.
    ///
    /// The templates are embedded at compile time via `include_str!` and are
    /// registered under `audit_report.typ`, `audit_report.json`,
    /// `audit_report.html` and `audit_summary.md`.
    ///
    /// # Errors
    ///
    /// Returns an error if any embedded template fails to parse.
    pub fn new() -> Result<Self> {
        // (registered name, template source, error context) in registration order.
        let builtin_templates = [
            (
                "audit_report.typ",
                include_str!("../../templates/audit_report.typ"),
                "Failed to register Typst template",
            ),
            (
                "audit_report.json",
                include_str!("../../templates/audit_report.json"),
                "Failed to register JSON template",
            ),
            (
                "audit_report.html",
                include_str!("../../templates/audit_report.html"),
                "Failed to register HTML template",
            ),
            (
                "audit_summary.md",
                include_str!("../../templates/audit_summary.md"),
                "Failed to register Markdown template",
            ),
        ];

        let mut tera = Tera::default();
        for (name, source, failure) in builtin_templates {
            tera.add_raw_template(name, source).context(failure)?;
        }
        Ok(Self { tera })
    }

    /// Renders `report` with the registered template `template_name` and
    /// writes the result to `output_path`.
    ///
    /// # Errors
    ///
    /// Fails if the template is unknown, rendering fails, or the file cannot
    /// be written.
    pub fn generate_report(
        &self,
        report: &AuditReport,
        template_name: &str,
        output_path: &Path,
    ) -> Result<()> {
        self.generate_report_with_optional_template(report, template_name, output_path, None)
    }

    /// Renders `report` and writes it to `output_path`.
    ///
    /// When `external_template_path` is `Some`, the template is read from that
    /// file instead of using the registered `template_name`. The external
    /// template is still registered under `template_name` in a private Tera
    /// instance, so it gets the same suffix-based autoescaping as the built-in
    /// template it replaces. For example, `*.html` output is HTML-escaped.
    ///
    /// # Errors
    ///
    /// Fails if the external template cannot be read or parsed, rendering
    /// fails, or the output file cannot be written.
    pub fn generate_report_with_optional_template(
        &self,
        report: &AuditReport,
        template_name: &str,
        output_path: &Path,
        external_template_path: Option<&str>,
    ) -> Result<()> {
        let context = self.create_template_context(report)?;
        let rendered = match external_template_path {
            Some(external_path) => Self::render_external(template_name, external_path, &context)?,
            None => self
                .tera
                .render(template_name, &context)
                .with_context(|| format!("Failed to render template: {}", template_name))?,
        };
        std::fs::write(output_path, rendered).with_context(|| {
            format!(
                "Failed to write rendered report to {}",
                output_path.display()
            )
        })?;
        Ok(())
    }

    /// Reads, parses and renders a one-off template from `external_path`.
    fn render_external(
        template_name: &str,
        external_path: &str,
        context: &TeraContext,
    ) -> Result<String> {
        let external_content = std::fs::read_to_string(external_path)
            .with_context(|| format!("Failed to read external template: {}", external_path))?;
        // Tera decides autoescaping from the template *name's* suffix. Registering
        // under a bare name such as "external" would silently disable HTML
        // escaping for user-supplied HTML templates. Reusing the built-in name
        // keeps the escaping behaviour identical to the built-in template.
        let mut temp_tera = Tera::default();
        temp_tera
            .add_raw_template(template_name, &external_content)
            .with_context(|| format!("Failed to parse external template: {}", external_path))?;
        temp_tera
            .render(template_name, context)
            .with_context(|| format!("Failed to render external template: {}", external_path))
    }

    /// Writes the built-in Typst document for `report` to `output_path`.
    pub fn generate_typst_document(&self, report: &AuditReport, output_path: &Path) -> Result<()> {
        self.generate_report(report, "audit_report.typ", output_path)
    }

    /// Writes a Typst document, optionally using an external template file.
    pub fn generate_typst_document_with_template(
        &self,
        report: &AuditReport,
        output_path: &Path,
        external_template: Option<&str>,
    ) -> Result<()> {
        self.generate_report_with_optional_template(
            report,
            "audit_report.typ",
            output_path,
            external_template,
        )
    }

    /// Writes the built-in JSON report for `report` to `output_path`.
    pub fn generate_json_report(&self, report: &AuditReport, output_path: &Path) -> Result<()> {
        self.generate_report(report, "audit_report.json", output_path)
    }

    /// Writes a JSON report, optionally using an external template file.
    pub fn generate_json_report_with_template(
        &self,
        report: &AuditReport,
        output_path: &Path,
        external_template: Option<&str>,
    ) -> Result<()> {
        self.generate_report_with_optional_template(
            report,
            "audit_report.json",
            output_path,
            external_template,
        )
    }

    /// Writes the built-in HTML report for `report` to `output_path`.
    pub fn generate_html_report(&self, report: &AuditReport, output_path: &Path) -> Result<()> {
        self.generate_report(report, "audit_report.html", output_path)
    }

    /// Writes an HTML report, optionally using an external template file.
    pub fn generate_html_report_with_template(
        &self,
        report: &AuditReport,
        output_path: &Path,
        external_template: Option<&str>,
    ) -> Result<()> {
        self.generate_report_with_optional_template(
            report,
            "audit_report.html",
            output_path,
            external_template,
        )
    }

    /// Writes the built-in Markdown summary for `report` to `output_path`.
    pub fn generate_markdown_summary(
        &self,
        report: &AuditReport,
        output_path: &Path,
    ) -> Result<()> {
        self.generate_report(report, "audit_summary.md", output_path)
    }

    /// Writes a Markdown summary, optionally using an external template file.
    pub fn generate_markdown_summary_with_template(
        &self,
        report: &AuditReport,
        output_path: &Path,
        external_template: Option<&str>,
    ) -> Result<()> {
        self.generate_report_with_optional_template(
            report,
            "audit_summary.md",
            output_path,
            external_template,
        )
    }

    /// Builds the Tera context shared by every template.
    ///
    /// Exposes the whole `report`, plus these convenience keys:
    /// - `timestamp`: formatted as `%Y-%m-%d %H:%M:%S UTC`.
    /// - `version`, `summary`, `security_score`, `compliance_level`.
    /// - `categorized_findings` (by category) and `severity_findings` (by the
    ///   severity's `Debug` name).
    /// - `system_info`, `statistics`, `recommendations`, `compliance_notes`,
    ///   `deeplogic_findings`.
    /// - The flags `has_critical` and `has_high`, and `total_serious`
    ///   (critical + high).
    fn create_template_context(&self, report: &AuditReport) -> Result<TeraContext> {
        let summary = &report.summary;
        let mut context = TeraContext::new();
        context.insert("report", report);
        context.insert(
            "timestamp",
            &report.timestamp.format("%Y-%m-%d %H:%M:%S UTC").to_string(),
        );
        context.insert("version", &report.version);
        context.insert("summary", summary);
        context.insert("security_score", &summary.security_score);
        context.insert("compliance_level", &summary.compliance_level);
        context.insert(
            "categorized_findings",
            &self.categorize_findings(&report.findings),
        );
        context.insert("severity_findings", &self.group_by_severity(&report.findings));
        context.insert("system_info", &report.system_info);
        context.insert("statistics", &self.calculate_statistics(&report.findings));
        context.insert("recommendations", &report.recommendations);
        context.insert("compliance_notes", &report.compliance_notes);
        context.insert("deeplogic_findings", &report.deeplogic_findings);
        context.insert("has_critical", &(summary.critical_findings > 0));
        context.insert("has_high", &(summary.high_findings > 0));
        context.insert(
            "total_serious",
            &(summary.critical_findings + summary.high_findings),
        );
        Ok(context)
    }

    /// Groups findings by their `category` string.
    fn categorize_findings<'a>(
        &self,
        findings: &'a [AuditFinding],
    ) -> HashMap<String, Vec<&'a AuditFinding>> {
        Self::group_findings(findings, |finding| finding.category.clone())
    }

    /// Groups findings by the `Debug` rendering of their severity.
    fn group_by_severity<'a>(
        &self,
        findings: &'a [AuditFinding],
    ) -> HashMap<String, Vec<&'a AuditFinding>> {
        Self::group_findings(findings, |finding| format!("{:?}", finding.severity))
    }

    /// Buckets `findings` by `key`, keeping input order within each bucket.
    fn group_findings<'a>(
        findings: &'a [AuditFinding],
        key: impl Fn(&AuditFinding) -> String,
    ) -> HashMap<String, Vec<&'a AuditFinding>> {
        let mut groups: HashMap<String, Vec<&'a AuditFinding>> = HashMap::new();
        for finding in findings {
            groups.entry(key(finding)).or_default().push(finding);
        }
        groups
    }

    /// Computes aggregate counts over `findings` for the `statistics` context key.
    ///
    /// `average_cvss_score` averages only findings that carry a CVSS score.
    /// `coverage_percentage` is the share of findings with a code location.
    /// Both are `0.0` when there is nothing to average.
    fn calculate_statistics(&self, findings: &[AuditFinding]) -> serde_json::Value {
        let total = findings.len();
        let with_cwe = findings.iter().filter(|f| f.cwe_id.is_some()).count();
        let with_cvss = findings.iter().filter(|f| f.cvss_score.is_some()).count();
        let with_location = findings
            .iter()
            .filter(|f| f.code_location.is_some())
            .count();
        let unique_categories = findings
            .iter()
            .map(|f| &f.category)
            .collect::<std::collections::HashSet<_>>()
            .len();
        let avg_cvss = if with_cvss > 0 {
            findings.iter().filter_map(|f| f.cvss_score).sum::<f32>() / with_cvss as f32
        } else {
            0.0
        };
        let coverage = if total > 0 {
            (with_location as f32 / total as f32) * 100.0
        } else {
            0.0
        };
        json!({
            "total_findings": total,
            "findings_with_cwe": with_cwe,
            "findings_with_cvss": with_cvss,
            "findings_with_location": with_location,
            "unique_categories": unique_categories,
            "average_cvss_score": avg_cvss,
            "coverage_percentage": coverage
        })
    }

    /// Registers (or replaces) a template under `name`.
    ///
    /// # Errors
    ///
    /// Returns an error if `content` is not a valid Tera template.
    pub fn add_template(&mut self, name: &str, content: &str) -> Result<()> {
        self.tera
            .add_raw_template(name, content)
            .with_context(|| format!("Failed to add template: {}", name))?;
        Ok(())
    }

    /// Returns the names of all registered templates, in no particular order.
    pub fn list_templates(&self) -> Vec<String> {
        self.tera
            .get_template_names()
            .map(|s| s.to_string())
            .collect()
    }
}
impl Default for TemplateReportGenerator {
    /// Builds a generator with the embedded templates.
    ///
    /// # Panics
    ///
    /// Panics if [`TemplateReportGenerator::new`] fails. That only happens if
    /// one of the templates compiled into the binary does not parse.
    fn default() -> Self {
        match Self::new() {
            Ok(generator) => generator,
            Err(err) => panic!("Failed to create template generator: {:?}", err),
        }
    }
}
/// Classifies AI-service errors and rate-limits how often their details are logged.
#[derive(Debug)]
pub struct EnhancedAIErrorHandler {
    /// Moment detailed error logging last happened, guarded by a mutex.
    last_error_log: std::sync::Arc<std::sync::Mutex<std::time::Instant>>,
    /// Minimum time that must pass between two detailed error logs.
    error_log_threshold: std::time::Duration,
}
impl EnhancedAIErrorHandler {
pub fn new() -> Self {
Self {
last_error_log: std::sync::Arc::new(std::sync::Mutex::new(
std::time::Instant::now() - std::time::Duration::from_secs(60),
)),
error_log_threshold: std::time::Duration::from_secs(30), }
}
fn should_log_error(&self) -> bool {
if let Ok(mut last_log) = self.last_error_log.try_lock() {
let now = std::time::Instant::now();
if now.duration_since(*last_log) > self.error_log_threshold {
*last_log = now;
true
} else {
false
}
} else {
false }
}
pub fn handle_ai_error_with_context(
&self,
error: &anyhow::Error,
context: &str,
) -> Option<String> {
if self.should_log_error() {
log::error!("AI Analysis Error in {}: {}", context, error);
log::warn!("AI error logging rate-limited to prevent log flooding");
let mut current_error = error.source();
let mut error_level = 1;
while let Some(err) = current_error {
log::error!(" Error level {}: {}", error_level, err);
current_error = err.source();
error_level += 1;
}
}
Self::handle_ai_error(error, context)
}
pub fn handle_ai_error(error: &anyhow::Error, _context: &str) -> Option<String> {
let error_string = error.to_string();
if error_string.contains("rate limit") || error_string.contains("429") {
log::warn!("Rate limit exceeded - consider reducing API calls or increasing delays");
Some("Rate limit exceeded".to_string())
} else if error_string.contains("timeout") {
log::warn!("AI request timeout - network or service issues");
Some("Request timeout".to_string())
} else if error_string.contains("api_key") || error_string.contains("unauthorized") {
log::error!("AI API authentication failed - check API key");
Some("Authentication failed".to_string())
} else if error_string.contains("network") || error_string.contains("connection") {
log::warn!("Network connectivity issues with AI service");
Some("Network error".to_string())
} else {
log::error!("Unknown AI service error - disabling AI analysis for this session");
None
}
}
pub fn generate_fallback_analysis(finding_title: &str, finding_category: &str) -> String {
format!(
"AI analysis unavailable. Manual review recommended for {} in category {}. \
Consider checking relevant security guidelines and best practices.",
finding_title, finding_category
)
}
pub fn should_disable_ai(error: &anyhow::Error) -> bool {
let error_string = error.to_string();
error_string.contains("api_key")
|| error_string.contains("unauthorized")
|| error_string.contains("quota")
|| error_string.contains("suspended")
}
pub fn get_retry_strategy(error: &anyhow::Error) -> (bool, std::time::Duration) {
let error_string = error.to_string();
let base_delay = if error_string.contains("rate limit") || error_string.contains("429") {
60 } else if error_string.contains("timeout") || error_string.contains("network") {
5 } else if error_string.contains("server") || error_string.contains("503") {
30 } else {
return (false, std::time::Duration::from_secs(0)); };
let jitter = rand::random::<f64>() * 0.5 + 0.5; let backoff_delay = (base_delay as f64 * jitter) as u64;
(true, std::time::Duration::from_secs(backoff_delay))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;
    use std::collections::HashMap;

    /// Builds a report with a single High-severity finding in category "Test".
    fn create_test_report() -> AuditReport {
        let findings = vec![AuditFinding {
            id: "TEST-001".to_string(),
            title: "Test finding".to_string(),
            description: "Test description".to_string(),
            severity: AuditSeverity::High,
            category: "Test".to_string(),
            cwe_id: Some("CWE-123".to_string()),
            cvss_score: Some(7.5),
            impact: "Test impact".to_string(),
            recommendation: "Test recommendation".to_string(),
            code_location: Some("test.rs".to_string()),
            references: vec!["https://example.com".to_string()],
        }];
        let summary = AuditSummary {
            total_findings: 1,
            critical_findings: 0,
            high_findings: 1,
            medium_findings: 0,
            low_findings: 0,
            info_findings: 0,
            security_score: 75.0,
            compliance_level: "Good".to_string(),
        };
        let system_info = SystemInfo {
            rust_version: "1.70.0".to_string(),
            solana_version: Some("1.16.0".to_string()),
            os_info: "Linux".to_string(),
            architecture: "x86_64".to_string(),
            dependencies: HashMap::new(),
        };
        AuditReport {
            timestamp: Utc::now(),
            version: "1.0.0".to_string(),
            summary,
            findings,
            deeplogic_findings: Vec::new(),
            system_info,
            recommendations: vec!["Test recommendation".to_string()],
            compliance_notes: vec!["Test compliance note".to_string()],
        }
    }

    #[test]
    fn test_template_generator_creation() {
        let generator = TemplateReportGenerator::new();
        assert!(generator.is_ok());
    }

    #[test]
    fn test_builtin_templates_are_registered() {
        let generator = TemplateReportGenerator::new().unwrap();
        let names = generator.list_templates();
        for expected in [
            "audit_report.typ",
            "audit_report.json",
            "audit_report.html",
            "audit_summary.md",
        ] {
            assert!(
                names.iter().any(|name| name == expected),
                "missing template {}",
                expected
            );
        }
    }

    #[test]
    fn test_categorize_findings() {
        let generator = TemplateReportGenerator::new().unwrap();
        let report = create_test_report();
        let categorized = generator.categorize_findings(&report.findings);
        assert!(categorized.contains_key("Test"));
        assert_eq!(categorized["Test"].len(), 1);
    }

    #[test]
    fn test_group_by_severity_keeps_every_finding() {
        let generator = TemplateReportGenerator::new().unwrap();
        let report = create_test_report();
        let grouped = generator.group_by_severity(&report.findings);
        assert_eq!(grouped.len(), 1);
        assert_eq!(grouped.values().map(Vec::len).sum::<usize>(), 1);
    }

    #[test]
    fn test_calculate_statistics() {
        let generator = TemplateReportGenerator::new().unwrap();
        let report = create_test_report();
        let stats = generator.calculate_statistics(&report.findings);
        assert_eq!(stats["total_findings"].as_u64(), Some(1));
        assert_eq!(stats["findings_with_cwe"].as_u64(), Some(1));
        assert_eq!(stats["findings_with_cvss"].as_u64(), Some(1));
        assert_eq!(stats["findings_with_location"].as_u64(), Some(1));
        assert_eq!(stats["unique_categories"].as_u64(), Some(1));
        assert_eq!(stats["average_cvss_score"].as_f64(), Some(7.5));
        assert_eq!(stats["coverage_percentage"].as_f64(), Some(100.0));
    }

    #[test]
    fn test_calculate_statistics_empty() {
        let generator = TemplateReportGenerator::new().unwrap();
        let stats = generator.calculate_statistics(&[]);
        assert_eq!(stats["total_findings"].as_u64(), Some(0));
        assert_eq!(stats["average_cvss_score"].as_f64(), Some(0.0));
        assert_eq!(stats["coverage_percentage"].as_f64(), Some(0.0));
    }

    #[test]
    fn test_generate_report_with_added_template() {
        let mut generator = TemplateReportGenerator::new().unwrap();
        generator
            .add_template(
                "custom.txt",
                "{{ version }}|{{ statistics.total_findings }}|{{ has_high }}",
            )
            .unwrap();
        assert!(generator
            .list_templates()
            .iter()
            .any(|name| name == "custom.txt"));

        let report = create_test_report();
        let output = std::env::temp_dir().join(format!(
            "audit_template_test_{}.txt",
            std::process::id()
        ));
        generator
            .generate_report(&report, "custom.txt", &output)
            .unwrap();
        let rendered = std::fs::read_to_string(&output).unwrap();
        let _ = std::fs::remove_file(&output);
        assert_eq!(rendered, "1.0.0|1|true");
    }

    #[test]
    fn test_ai_error_handler() {
        let error = anyhow::anyhow!("rate limit exceeded");
        let result = EnhancedAIErrorHandler::handle_ai_error(&error, "test");
        assert_eq!(result, Some("Rate limit exceeded".to_string()));

        let (should_retry, delay) = EnhancedAIErrorHandler::get_retry_strategy(&error);
        assert!(should_retry);
        assert!(delay.as_secs() > 0);
    }

    #[test]
    fn test_ai_error_classification_and_retry() {
        let timeout = anyhow::anyhow!("timeout while waiting");
        assert_eq!(
            EnhancedAIErrorHandler::handle_ai_error(&timeout, "test"),
            Some("Request timeout".to_string())
        );
        let (should_retry, delay) = EnhancedAIErrorHandler::get_retry_strategy(&timeout);
        assert!(should_retry);
        assert!(delay.as_secs() <= 5);

        let unknown = anyhow::anyhow!("something odd happened");
        assert_eq!(EnhancedAIErrorHandler::handle_ai_error(&unknown, "test"), None);
        assert_eq!(
            EnhancedAIErrorHandler::get_retry_strategy(&unknown),
            (false, std::time::Duration::from_secs(0))
        );

        let handler = EnhancedAIErrorHandler::new();
        assert_eq!(
            handler.handle_ai_error_with_context(&timeout, "test"),
            Some("Request timeout".to_string())
        );
    }

    #[test]
    fn test_should_disable_ai() {
        assert!(EnhancedAIErrorHandler::should_disable_ai(&anyhow::anyhow!(
            "unauthorized request"
        )));
        assert!(EnhancedAIErrorHandler::should_disable_ai(&anyhow::anyhow!(
            "quota exhausted"
        )));
        assert!(!EnhancedAIErrorHandler::should_disable_ai(&anyhow::anyhow!(
            "rate limit exceeded"
        )));
    }

    #[test]
    fn test_fallback_analysis() {
        let fallback =
            EnhancedAIErrorHandler::generate_fallback_analysis("Test Finding", "Security");
        assert!(fallback.contains("Test Finding"));
        assert!(fallback.contains("Security"));
    }
}