1use anyhow::{Context, Result};
10use tracing::instrument;
11
12use super::types::{Finding, ValidatedFinding, ValidationResult};
13use crate::ai::client::AiClient;
14use crate::ai::provider::AiProvider;
15use crate::ai::types::{ChatCompletionRequest, ChatMessage, ResponseFormat};
16
/// Number of source lines shown on each side of a finding's line in the snippet that
/// `extract_snippet` attaches to the validation prompt.
const CONTEXT_LINES: usize = 10;
19
/// Top-level JSON object the model is instructed to return (schema described in
/// `SecurityValidator::build_system_prompt`).
#[derive(serde::Deserialize)]
struct ValidationResponse {
    /// Verdicts from the model; each carries the 0-based batch index of the finding it judges.
    results: Vec<ValidationResult>,
}
25
/// Asks an LLM (through [`AiClient`]) whether pattern-based security findings are real
/// issues or false positives, falling back to the pattern's own confidence when the model
/// cannot answer.
#[derive(Debug)]
pub struct SecurityValidator {
    /// Client used to send validation requests; its model name is also recorded on each
    /// LLM-validated finding.
    ai_client: AiClient,
}
35
36impl SecurityValidator {
37 pub fn new(ai_client: AiClient) -> Self {
43 Self { ai_client }
44 }
45
46 #[instrument(skip(self, findings, file_contents), fields(count = findings.len()))]
60 pub async fn validate_findings_batch(
61 &self,
62 findings: &[Finding],
63 file_contents: &std::collections::HashMap<String, String>,
64 ) -> Result<Vec<ValidatedFinding>> {
65 if findings.is_empty() {
66 return Ok(Vec::new());
67 }
68
69 let prompt = Self::build_batch_validation_prompt(findings, file_contents);
71
72 let request = ChatCompletionRequest {
74 model: self.ai_client.model().to_string(),
75 messages: vec![
76 ChatMessage {
77 role: "system".to_string(),
78 content: Some(Self::build_system_prompt()),
79 reasoning: None,
80 },
81 ChatMessage {
82 role: "user".to_string(),
83 content: Some(prompt),
84 reasoning: None,
85 },
86 ],
87 response_format: Some(ResponseFormat {
88 format_type: "json_object".to_string(),
89 json_schema: None,
90 }),
91 max_tokens: Some(self.ai_client.max_tokens()),
92 temperature: Some(0.3),
93 };
94
95 match self.send_and_parse(&request).await {
97 Ok(results) => {
98 let mut validated = Vec::new();
100 for (i, finding) in findings.iter().enumerate() {
101 if let Some(result) = results.iter().find(|r| r.index == i) {
102 validated.push(ValidatedFinding {
103 finding: finding.clone(),
104 is_valid: result.is_valid,
105 reasoning: result.reasoning.clone(),
106 model_version: Some(self.ai_client.model().to_string()),
107 });
108 } else {
109 validated.push(Self::fallback_validation(finding));
111 }
112 }
113 Ok(validated)
114 }
115 Err(e) => {
116 tracing::warn!(error = %e, "LLM validation failed, using pattern confidence");
118 Ok(findings.iter().map(Self::fallback_validation).collect())
119 }
120 }
121 }
122
123 fn build_system_prompt() -> String {
125 r#"You are a security code reviewer. Analyze the provided security findings and determine if they are real vulnerabilities or false positives.
126
127Your response MUST be valid JSON with this exact schema:
128{
129 "results": [
130 {
131 "index": 0,
132 "is_valid": true,
133 "reasoning": "Brief explanation of why this is/isn't a real issue"
134 }
135 ]
136}
137
138Guidelines:
139- index: The 0-based index of the finding in the batch
140- is_valid: true if this is a real security issue, false if it's a false positive
141- reasoning: 1-2 sentence explanation of your decision
142
143Consider:
1441. Context: Is the code actually vulnerable in its usage context?
1452. False positives: Test data, comments, documentation, or safe patterns?
1463. Severity: Does the finding match the claimed severity?
1474. Mitigation: Are there compensating controls in place?
148
149Be conservative: when in doubt, mark as valid to avoid missing real issues."#
150 .to_string()
151 }
152
153 fn build_batch_validation_prompt(
155 findings: &[Finding],
156 file_contents: &std::collections::HashMap<String, String>,
157 ) -> String {
158 use std::fmt::Write;
159
160 let mut prompt = String::new();
161 prompt.push_str("Analyze these security findings:\n\n");
162
163 for (i, finding) in findings.iter().enumerate() {
164 let _ = writeln!(prompt, "Finding {i}:");
165 let _ = writeln!(prompt, " Pattern: {}", finding.pattern_id);
166 let _ = writeln!(prompt, " Description: {}", finding.description);
167 let _ = writeln!(
168 prompt,
169 " Severity: {:?}, Confidence: {:?}",
170 finding.severity, finding.confidence
171 );
172 let _ = writeln!(
173 prompt,
174 " File: {}:{}",
175 finding.file_path, finding.line_number
176 );
177 let _ = writeln!(prompt, " Matched: {}", finding.matched_text);
178
179 if let Some(snippet) =
181 extract_snippet(file_contents.get(&finding.file_path), finding.line_number)
182 {
183 let _ = writeln!(prompt, " Context:\n{snippet}");
184 }
185
186 prompt.push('\n');
187 }
188
189 prompt
190 }
191
192 async fn send_and_parse(
194 &self,
195 request: &ChatCompletionRequest,
196 ) -> Result<Vec<ValidationResult>> {
197 let completion = self.ai_client.send_request_inner(request).await?;
199
200 let content = completion
202 .choices
203 .first()
204 .and_then(|c| {
205 c.message
206 .content
207 .clone()
208 .or_else(|| c.message.reasoning.clone())
209 })
210 .context("No response from AI model")?;
211
212 let response: ValidationResponse = serde_json::from_str(&content)
214 .context("Failed to parse validation response as JSON")?;
215
216 Ok(response.results)
217 }
218
219 fn fallback_validation(finding: &Finding) -> ValidatedFinding {
221 use super::types::Confidence;
222
223 let is_valid = matches!(finding.confidence, Confidence::High | Confidence::Medium);
224 let reasoning = format!(
225 "LLM validation unavailable, using pattern confidence: {:?}",
226 finding.confidence
227 );
228
229 ValidatedFinding {
230 finding: finding.clone(),
231 is_valid,
232 reasoning,
233 model_version: None,
234 }
235 }
236}
237
238fn extract_snippet(content: Option<&String>, line_number: usize) -> Option<String> {
249 use std::fmt::Write;
250
251 let content = content?;
252 let lines: Vec<&str> = content.lines().collect();
253
254 if line_number == 0 || line_number > lines.len() {
255 return None;
256 }
257
258 let target_idx = line_number - 1;
260 let start = target_idx.saturating_sub(CONTEXT_LINES);
261 let end = (target_idx + CONTEXT_LINES + 1).min(lines.len());
262
263 let mut snippet = String::new();
264 for (i, line) in lines[start..end].iter().enumerate() {
265 let line_num = start + i + 1;
266 let marker = if line_num == line_number { ">" } else { " " };
267 let _ = writeln!(snippet, "{marker} {line_num:4} | {line}");
268 }
269
270 Some(snippet)
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::types::{Confidence, Severity};

    /// Builds a minimal finding whose only varying attribute is its pattern confidence.
    fn finding_with_confidence(confidence: Confidence) -> Finding {
        Finding {
            pattern_id: "test-pattern".to_string(),
            description: "Test finding".to_string(),
            severity: Severity::High,
            confidence,
            file_path: "test.rs".to_string(),
            line_number: 1,
            matched_text: "test".to_string(),
            cwe: None,
        }
    }

    // Snippet rows are formatted as "{marker} {line_num:4} | {line}": the line number is
    // right-aligned in a 4-wide column, so the target line 4 renders as ">    4 | line 4"
    // and a context line 1 renders as "     1 | line 1".

    #[test]
    fn test_extract_snippet_with_context() {
        let content = "line 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\n".to_string();
        let snippet = extract_snippet(Some(&content), 4).expect("line 4 exists");

        assert!(snippet.contains(">    4 | line 4"));
        assert!(snippet.contains("     1 | line 1"));
        assert!(snippet.contains("     7 | line 7"));
    }

    #[test]
    fn test_extract_snippet_at_start() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 1).expect("line 1 exists");

        assert!(snippet.contains(">    1 | line 1"));
        assert!(!snippet.contains(" 0 |"));
    }

    #[test]
    fn test_extract_snippet_at_end() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 3).expect("line 3 exists");

        assert!(snippet.contains(">    3 | line 3"));
    }

    #[test]
    fn test_extract_snippet_invalid_line() {
        let content = "line 1\nline 2\n".to_string();
        assert!(extract_snippet(Some(&content), 10).is_none());
    }

    #[test]
    fn test_extract_snippet_line_zero() {
        // Line numbers are 1-based; 0 never refers to a real line.
        let content = "line 1\n".to_string();
        assert!(extract_snippet(Some(&content), 0).is_none());
    }

    #[test]
    fn test_extract_snippet_missing_file() {
        assert!(extract_snippet(None, 1).is_none());
    }

    #[test]
    fn test_extract_snippet_window_is_clamped_to_context_lines() {
        let target = CONTEXT_LINES + 5;
        let total = target + CONTEXT_LINES + 5;
        let content: String = (1..=total).map(|i| format!("line {i}\n")).collect();
        let snippet = extract_snippet(Some(&content), target).expect("target line exists");
        let rows: Vec<&str> = snippet.lines().collect();

        let first = target - CONTEXT_LINES;
        let last = target + CONTEXT_LINES;
        assert_eq!(rows.len(), 2 * CONTEXT_LINES + 1);
        assert_eq!(rows[0], format!("  {first:4} | line {first}"));
        assert_eq!(rows[rows.len() - 1], format!("  {last:4} | line {last}"));
        assert!(rows.contains(&format!("> {target:4} | line {target}").as_str()));
    }

    #[test]
    fn test_fallback_validation_high_confidence() {
        let finding = finding_with_confidence(Confidence::High);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("High"));
    }

    #[test]
    fn test_fallback_validation_low_confidence() {
        let finding = finding_with_confidence(Confidence::Low);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(!validated.is_valid);
        assert!(validated.reasoning.contains("Low"));
    }

    #[test]
    fn test_build_system_prompt() {
        let prompt = SecurityValidator::build_system_prompt();
        assert!(prompt.contains("security code reviewer"));
        assert!(prompt.contains("\"results\""));
        assert!(prompt.contains("\"index\""));
        assert!(prompt.contains("\"is_valid\""));
        assert!(prompt.contains("\"reasoning\""));
    }

    #[test]
    fn test_parse_validation_response() {
        let json = r#"{
            "results": [
                {
                    "index": 0,
                    "is_valid": true,
                    "reasoning": "This is a real vulnerability"
                },
                {
                    "index": 1,
                    "is_valid": false,
                    "reasoning": "This is test data"
                }
            ]
        }"#;

        let response: ValidationResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].index, 0);
        assert!(response.results[0].is_valid);
        assert_eq!(response.results[1].index, 1);
        assert!(!response.results[1].is_valid);
    }
}