mockforge_intelligence/threat_modeling/
remediation_generator.rs1use super::types::{RemediationSuggestion, ThreatFinding};
7use crate::intelligent_behavior::config::BehaviorModelConfig;
8use crate::intelligent_behavior::llm_client::LlmClient;
9use crate::intelligent_behavior::types::LlmGenerationRequest;
10use mockforge_foundation::Result;
11
12pub struct RemediationGenerator {
14 llm_client: Option<LlmClient>,
16 enabled: bool,
18}
19
20impl RemediationGenerator {
21 pub fn new(
23 enabled: bool,
24 llm_provider: String,
25 llm_model: String,
26 api_key: Option<String>,
27 ) -> Result<Self> {
28 let llm_client = if enabled {
29 let llm_config = BehaviorModelConfig {
30 llm_provider: llm_provider.clone(),
31 model: llm_model.clone(),
32 api_key: api_key.clone(),
33 api_endpoint: None,
34 temperature: 0.3, max_tokens: 2000,
36 rules: crate::intelligent_behavior::BehaviorRules::default(),
37 };
38
39 Some(LlmClient::new(llm_config))
40 } else {
41 None
42 };
43
44 Ok(Self {
45 llm_client,
46 enabled,
47 })
48 }
49
50 pub async fn generate_remediations(
52 &self,
53 findings: &[ThreatFinding],
54 ) -> Result<Vec<RemediationSuggestion>> {
55 if !self.enabled || self.llm_client.is_none() {
56 return Ok(self.generate_basic_remediations(findings));
57 }
58
59 let mut suggestions = Vec::new();
60
61 for finding in findings {
62 if let Some(ref llm_client) = self.llm_client {
63 match self.generate_ai_remediation(llm_client, finding).await {
64 Ok(suggestion) => suggestions.push(suggestion),
65 Err(e) => {
66 suggestions.push(self.generate_basic_remediation(finding));
68 tracing::warn!("Failed to generate AI remediation: {}", e);
69 }
70 }
71 }
72 }
73
74 Ok(suggestions)
75 }
76
77 async fn generate_ai_remediation(
79 &self,
80 llm_client: &LlmClient,
81 finding: &ThreatFinding,
82 ) -> Result<RemediationSuggestion> {
83 let prompt = self.build_remediation_prompt(finding);
84
85 let request = LlmGenerationRequest::new(self.get_system_prompt(), prompt)
86 .with_temperature(0.3)
87 .with_max_tokens(2000);
88
89 let response = llm_client.generate(&request).await?;
90
91 let suggestion_text = response
93 .get("suggestion")
94 .and_then(|v| v.as_str())
95 .map(|s| s.to_string())
96 .or_else(|| {
97 response
99 .get("response")
100 .and_then(|v| v.as_str())
101 .map(|s| s.to_string())
102 .or_else(|| {
103 serde_json::to_string(&response).ok()
105 })
106 })
107 .unwrap_or_else(|| "No remediation suggestion available".to_string());
108
109 let code_example =
110 response.get("code_example").and_then(|v| v.as_str()).map(|s| s.to_string());
111
112 let confidence = response.get("confidence").and_then(|v| v.as_f64()).unwrap_or(0.7);
113
114 Ok(RemediationSuggestion {
115 finding_id: format!("finding_{}", finding.field_path.as_deref().unwrap_or("unknown")),
116 suggestion: suggestion_text,
117 code_example,
118 confidence,
119 ai_generated: true,
120 priority: self.calculate_priority(finding),
121 })
122 }
123
124 fn build_remediation_prompt(&self, finding: &ThreatFinding) -> String {
126 format!(
127 r#"Generate a remediation suggestion for this API security finding:
128
129Finding Type: {:?}
130Severity: {:?}
131Description: {}
132Field Path: {}
133
134Provide:
1351. A clear, actionable remediation suggestion
1362. A code example showing how to fix it (if applicable)
1373. Confidence score (0.0-1.0)
138
139Format your response as JSON:
140{{
141 "suggestion": "detailed suggestion text",
142 "code_example": "example code or schema change",
143 "confidence": 0.8
144}}"#,
145 finding.finding_type,
146 finding.severity,
147 finding.description,
148 finding.field_path.as_deref().unwrap_or("N/A")
149 )
150 }
151
152 fn get_system_prompt(&self) -> String {
154 "You are an expert API security analyst specializing in contract security and threat remediation. Provide clear, actionable remediation suggestions with code examples when applicable.".to_string()
155 }
156
157 fn calculate_priority(&self, finding: &ThreatFinding) -> u32 {
159 match finding.severity {
160 super::types::ThreatLevel::Critical => 1,
161 super::types::ThreatLevel::High => 2,
162 super::types::ThreatLevel::Medium => 3,
163 super::types::ThreatLevel::Low => 4,
164 }
165 }
166
167 fn generate_basic_remediations(
169 &self,
170 findings: &[ThreatFinding],
171 ) -> Vec<RemediationSuggestion> {
172 findings.iter().map(|f| self.generate_basic_remediation(f)).collect()
173 }
174
175 fn generate_basic_remediation(&self, finding: &ThreatFinding) -> RemediationSuggestion {
177 let (suggestion, code_example) = match finding.finding_type {
178 super::types::ThreatCategory::UnboundedArrays => (
179 "Add maxItems constraint to array schema to prevent DoS attacks".to_string(),
180 Some(
181 r#"{
182 "type": "array",
183 "items": {...},
184 "maxItems": 100
185}"#
186 .to_string(),
187 ),
188 ),
189 super::types::ThreatCategory::PiiExposure => (
190 "Review field name and ensure PII is properly masked or removed from responses"
191 .to_string(),
192 None,
193 ),
194 super::types::ThreatCategory::StackTraceLeakage => (
195 "Sanitize error messages to remove stack traces and internal details".to_string(),
196 Some(
197 r#"{
198 "error": {
199 "message": "An error occurred",
200 "code": "ERROR_CODE"
201 }
202}"#
203 .to_string(),
204 ),
205 ),
206 super::types::ThreatCategory::ExcessiveOptionalFields => (
207 "Consider making more fields required or splitting into separate schemas"
208 .to_string(),
209 None,
210 ),
211 _ => (format!("Address the {} issue in the API contract", finding.finding_type), None),
212 };
213
214 RemediationSuggestion {
215 finding_id: format!("finding_{}", finding.field_path.as_deref().unwrap_or("unknown")),
216 suggestion,
217 code_example,
218 confidence: 0.6,
219 ai_generated: false,
220 priority: self.calculate_priority(finding),
221 }
222 }
223}