Skip to main content

mockforge_intelligence/threat_modeling/
remediation_generator.rs

//! AI-powered remediation suggestion generator
//!
//! This module generates remediation suggestions for threat findings
//! using LLM analysis, falling back to rule-based suggestions when AI
//! generation is disabled or an LLM request fails.
5
6use super::types::{RemediationSuggestion, ThreatFinding};
7use crate::intelligent_behavior::config::BehaviorModelConfig;
8use crate::intelligent_behavior::llm_client::LlmClient;
9use crate::intelligent_behavior::types::LlmGenerationRequest;
10use mockforge_foundation::Result;
11
12/// Remediation generator using AI
13pub struct RemediationGenerator {
14    /// LLM client
15    llm_client: Option<LlmClient>,
16    /// Whether AI generation is enabled
17    enabled: bool,
18}
19
20impl RemediationGenerator {
21    /// Create a new remediation generator
22    pub fn new(
23        enabled: bool,
24        llm_provider: String,
25        llm_model: String,
26        api_key: Option<String>,
27    ) -> Result<Self> {
28        let llm_client = if enabled {
29            let llm_config = BehaviorModelConfig {
30                llm_provider: llm_provider.clone(),
31                model: llm_model.clone(),
32                api_key: api_key.clone(),
33                api_endpoint: None,
34                temperature: 0.3, // Lower temperature for precise suggestions
35                max_tokens: 2000,
36                rules: crate::intelligent_behavior::BehaviorRules::default(),
37            };
38
39            Some(LlmClient::new(llm_config))
40        } else {
41            None
42        };
43
44        Ok(Self {
45            llm_client,
46            enabled,
47        })
48    }
49
50    /// Generate remediation suggestions for findings
51    pub async fn generate_remediations(
52        &self,
53        findings: &[ThreatFinding],
54    ) -> Result<Vec<RemediationSuggestion>> {
55        if !self.enabled || self.llm_client.is_none() {
56            return Ok(self.generate_basic_remediations(findings));
57        }
58
59        let mut suggestions = Vec::new();
60
61        for finding in findings {
62            if let Some(ref llm_client) = self.llm_client {
63                match self.generate_ai_remediation(llm_client, finding).await {
64                    Ok(suggestion) => suggestions.push(suggestion),
65                    Err(e) => {
66                        // Fallback to basic remediation on error
67                        suggestions.push(self.generate_basic_remediation(finding));
68                        tracing::warn!("Failed to generate AI remediation: {}", e);
69                    }
70                }
71            }
72        }
73
74        Ok(suggestions)
75    }
76
77    /// Generate AI-powered remediation
78    async fn generate_ai_remediation(
79        &self,
80        llm_client: &LlmClient,
81        finding: &ThreatFinding,
82    ) -> Result<RemediationSuggestion> {
83        let prompt = self.build_remediation_prompt(finding);
84
85        let request = LlmGenerationRequest::new(self.get_system_prompt(), prompt)
86            .with_temperature(0.3)
87            .with_max_tokens(2000);
88
89        let response = llm_client.generate(&request).await?;
90
91        // Response is already a serde_json::Value
92        let suggestion_text = response
93            .get("suggestion")
94            .and_then(|v| v.as_str())
95            .map(|s| s.to_string())
96            .or_else(|| {
97                // Fallback: try to extract text from "response" field or use the whole value as string
98                response
99                    .get("response")
100                    .and_then(|v| v.as_str())
101                    .map(|s| s.to_string())
102                    .or_else(|| {
103                        // Last resort: serialize the whole response as a string
104                        serde_json::to_string(&response).ok()
105                    })
106            })
107            .unwrap_or_else(|| "No remediation suggestion available".to_string());
108
109        let code_example =
110            response.get("code_example").and_then(|v| v.as_str()).map(|s| s.to_string());
111
112        let confidence = response.get("confidence").and_then(|v| v.as_f64()).unwrap_or(0.7);
113
114        Ok(RemediationSuggestion {
115            finding_id: format!("finding_{}", finding.field_path.as_deref().unwrap_or("unknown")),
116            suggestion: suggestion_text,
117            code_example,
118            confidence,
119            ai_generated: true,
120            priority: self.calculate_priority(finding),
121        })
122    }
123
124    /// Build prompt for remediation generation
125    fn build_remediation_prompt(&self, finding: &ThreatFinding) -> String {
126        format!(
127            r#"Generate a remediation suggestion for this API security finding:
128
129Finding Type: {:?}
130Severity: {:?}
131Description: {}
132Field Path: {}
133
134Provide:
1351. A clear, actionable remediation suggestion
1362. A code example showing how to fix it (if applicable)
1373. Confidence score (0.0-1.0)
138
139Format your response as JSON:
140{{
141  "suggestion": "detailed suggestion text",
142  "code_example": "example code or schema change",
143  "confidence": 0.8
144}}"#,
145            finding.finding_type,
146            finding.severity,
147            finding.description,
148            finding.field_path.as_deref().unwrap_or("N/A")
149        )
150    }
151
152    /// Get system prompt
153    fn get_system_prompt(&self) -> String {
154        "You are an expert API security analyst specializing in contract security and threat remediation. Provide clear, actionable remediation suggestions with code examples when applicable.".to_string()
155    }
156
157    /// Calculate priority based on severity
158    fn calculate_priority(&self, finding: &ThreatFinding) -> u32 {
159        match finding.severity {
160            super::types::ThreatLevel::Critical => 1,
161            super::types::ThreatLevel::High => 2,
162            super::types::ThreatLevel::Medium => 3,
163            super::types::ThreatLevel::Low => 4,
164        }
165    }
166
167    /// Generate basic remediation without AI
168    fn generate_basic_remediations(
169        &self,
170        findings: &[ThreatFinding],
171    ) -> Vec<RemediationSuggestion> {
172        findings.iter().map(|f| self.generate_basic_remediation(f)).collect()
173    }
174
175    /// Generate a basic remediation for a finding
176    fn generate_basic_remediation(&self, finding: &ThreatFinding) -> RemediationSuggestion {
177        let (suggestion, code_example) = match finding.finding_type {
178            super::types::ThreatCategory::UnboundedArrays => (
179                "Add maxItems constraint to array schema to prevent DoS attacks".to_string(),
180                Some(
181                    r#"{
182  "type": "array",
183  "items": {...},
184  "maxItems": 100
185}"#
186                    .to_string(),
187                ),
188            ),
189            super::types::ThreatCategory::PiiExposure => (
190                "Review field name and ensure PII is properly masked or removed from responses"
191                    .to_string(),
192                None,
193            ),
194            super::types::ThreatCategory::StackTraceLeakage => (
195                "Sanitize error messages to remove stack traces and internal details".to_string(),
196                Some(
197                    r#"{
198  "error": {
199    "message": "An error occurred",
200    "code": "ERROR_CODE"
201  }
202}"#
203                    .to_string(),
204                ),
205            ),
206            super::types::ThreatCategory::ExcessiveOptionalFields => (
207                "Consider making more fields required or splitting into separate schemas"
208                    .to_string(),
209                None,
210            ),
211            _ => (format!("Address the {} issue in the API contract", finding.finding_type), None),
212        };
213
214        RemediationSuggestion {
215            finding_id: format!("finding_{}", finding.field_path.as_deref().unwrap_or("unknown")),
216            suggestion,
217            code_example,
218            confidence: 0.6,
219            ai_generated: false,
220            priority: self.calculate_priority(finding),
221        }
222    }
223}