1use anyhow::{Context, Result};
10use tracing::instrument;
11
12use super::types::{Finding, ValidatedFinding, ValidationResult};
13use crate::ai::client::AiClient;
14use crate::ai::provider::AiProvider;
15use crate::ai::types::{ChatCompletionRequest, ChatMessage, ResponseFormat};
16
/// Number of source lines shown on each side of the flagged line when a
/// context snippet is built for the validation prompt.
const CONTEXT_LINES: usize = 10;
19
/// Top-level JSON object the model is asked to return: `{ "results": [...] }`.
///
/// Only used as a deserialization target; the inner `results` are handed back
/// to the caller as-is.
#[derive(serde::Deserialize)]
struct ValidationResponse {
    /// One entry per finding the model chose to answer, keyed by its `index`.
    results: Vec<ValidationResult>,
}
25
/// Validates pattern-based security findings by asking an AI model whether
/// each one is a real issue or a false positive.
///
/// When the model cannot be reached, or omits a finding from its answer, the
/// finding's own pattern confidence decides the verdict instead.
#[derive(Debug)]
pub struct SecurityValidator {
    /// Client used to send chat-completion requests; also supplies the model
    /// name and the token limit placed on each request.
    ai_client: AiClient,
}
35
36impl SecurityValidator {
37 pub fn new(ai_client: AiClient) -> Self {
43 Self { ai_client }
44 }
45
46 #[instrument(skip(self, findings, file_contents), fields(count = findings.len()))]
60 pub async fn validate_findings_batch(
61 &self,
62 findings: &[Finding],
63 file_contents: &std::collections::HashMap<String, String>,
64 ) -> Result<Vec<ValidatedFinding>> {
65 if findings.is_empty() {
66 return Ok(Vec::new());
67 }
68
69 let prompt = Self::build_batch_validation_prompt(findings, file_contents);
71
72 let request = ChatCompletionRequest {
74 model: self.ai_client.model().to_string(),
75 messages: vec![
76 ChatMessage {
77 role: "system".to_string(),
78 content: Self::build_system_prompt(),
79 },
80 ChatMessage {
81 role: "user".to_string(),
82 content: prompt,
83 },
84 ],
85 response_format: Some(ResponseFormat {
86 format_type: "json_object".to_string(),
87 json_schema: None,
88 }),
89 max_tokens: Some(self.ai_client.max_tokens()),
90 temperature: Some(0.3),
91 };
92
93 match self.send_and_parse(&request).await {
95 Ok(results) => {
96 let mut validated = Vec::new();
98 for (i, finding) in findings.iter().enumerate() {
99 if let Some(result) = results.iter().find(|r| r.index == i) {
100 validated.push(ValidatedFinding {
101 finding: finding.clone(),
102 is_valid: result.is_valid,
103 reasoning: result.reasoning.clone(),
104 model_version: Some(self.ai_client.model().to_string()),
105 });
106 } else {
107 validated.push(Self::fallback_validation(finding));
109 }
110 }
111 Ok(validated)
112 }
113 Err(e) => {
114 tracing::warn!(error = %e, "LLM validation failed, using pattern confidence");
116 Ok(findings.iter().map(Self::fallback_validation).collect())
117 }
118 }
119 }
120
121 fn build_system_prompt() -> String {
123 r#"You are a security code reviewer. Analyze the provided security findings and determine if they are real vulnerabilities or false positives.
124
125Your response MUST be valid JSON with this exact schema:
126{
127 "results": [
128 {
129 "index": 0,
130 "is_valid": true,
131 "reasoning": "Brief explanation of why this is/isn't a real issue"
132 }
133 ]
134}
135
136Guidelines:
137- index: The 0-based index of the finding in the batch
138- is_valid: true if this is a real security issue, false if it's a false positive
139- reasoning: 1-2 sentence explanation of your decision
140
141Consider:
1421. Context: Is the code actually vulnerable in its usage context?
1432. False positives: Test data, comments, documentation, or safe patterns?
1443. Severity: Does the finding match the claimed severity?
1454. Mitigation: Are there compensating controls in place?
146
147Be conservative: when in doubt, mark as valid to avoid missing real issues."#
148 .to_string()
149 }
150
151 fn build_batch_validation_prompt(
153 findings: &[Finding],
154 file_contents: &std::collections::HashMap<String, String>,
155 ) -> String {
156 use std::fmt::Write;
157
158 let mut prompt = String::new();
159 prompt.push_str("Analyze these security findings:\n\n");
160
161 for (i, finding) in findings.iter().enumerate() {
162 let _ = writeln!(prompt, "Finding {i}:");
163 let _ = writeln!(prompt, " Pattern: {}", finding.pattern_id);
164 let _ = writeln!(prompt, " Description: {}", finding.description);
165 let _ = writeln!(
166 prompt,
167 " Severity: {:?}, Confidence: {:?}",
168 finding.severity, finding.confidence
169 );
170 let _ = writeln!(
171 prompt,
172 " File: {}:{}",
173 finding.file_path, finding.line_number
174 );
175 let _ = writeln!(prompt, " Matched: {}", finding.matched_text);
176
177 if let Some(snippet) =
179 extract_snippet(file_contents.get(&finding.file_path), finding.line_number)
180 {
181 let _ = writeln!(prompt, " Context:\n{snippet}");
182 }
183
184 prompt.push('\n');
185 }
186
187 prompt
188 }
189
190 async fn send_and_parse(
192 &self,
193 request: &ChatCompletionRequest,
194 ) -> Result<Vec<ValidationResult>> {
195 let completion = self.ai_client.send_request_inner(request).await?;
197
198 let content = completion
200 .choices
201 .first()
202 .map(|c| c.message.content.clone())
203 .context("No response from AI model")?;
204
205 let response: ValidationResponse = serde_json::from_str(&content)
207 .context("Failed to parse validation response as JSON")?;
208
209 Ok(response.results)
210 }
211
212 fn fallback_validation(finding: &Finding) -> ValidatedFinding {
214 use super::types::Confidence;
215
216 let is_valid = matches!(finding.confidence, Confidence::High | Confidence::Medium);
217 let reasoning = format!(
218 "LLM validation unavailable, using pattern confidence: {:?}",
219 finding.confidence
220 );
221
222 ValidatedFinding {
223 finding: finding.clone(),
224 is_valid,
225 reasoning,
226 model_version: None,
227 }
228 }
229}
230
231fn extract_snippet(content: Option<&String>, line_number: usize) -> Option<String> {
242 use std::fmt::Write;
243
244 let content = content?;
245 let lines: Vec<&str> = content.lines().collect();
246
247 if line_number == 0 || line_number > lines.len() {
248 return None;
249 }
250
251 let target_idx = line_number - 1;
253 let start = target_idx.saturating_sub(CONTEXT_LINES);
254 let end = (target_idx + CONTEXT_LINES + 1).min(lines.len());
255
256 let mut snippet = String::new();
257 for (i, line) in lines[start..end].iter().enumerate() {
258 let line_num = start + i + 1;
259 let marker = if line_num == line_number { ">" } else { " " };
260 let _ = writeln!(snippet, "{marker} {line_num:4} | {line}");
261 }
262
263 Some(snippet)
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::types::{Confidence, Severity};

    /// Builds a minimal finding whose only varying property is its confidence.
    fn finding_with_confidence(confidence: Confidence) -> Finding {
        Finding {
            pattern_id: "test-pattern".to_string(),
            description: "Test finding".to_string(),
            severity: Severity::High,
            confidence,
            file_path: "test.rs".to_string(),
            line_number: 1,
            matched_text: "test".to_string(),
            cwe: None,
        }
    }

    // Snippet lines are rendered as "{marker} {line_num:4} | {line}", so the
    // line number is right-aligned in a 4-character column. The expected
    // strings below spell that padding out explicitly.

    #[test]
    fn test_extract_snippet_with_context() {
        let content = "line 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\n".to_string();
        let snippet = extract_snippet(Some(&content), 4).expect("line 4 is in range");

        assert!(snippet.contains(">    4 | line 4"));
        assert!(snippet.contains("     1 | line 1"));
        assert!(snippet.contains("     7 | line 7"));
    }

    #[test]
    fn test_extract_snippet_at_start() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 1).expect("line 1 is in range");

        assert!(snippet.contains(">    1 | line 1"));
        assert!(!snippet.contains(" 0 |"));
    }

    #[test]
    fn test_extract_snippet_at_end() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 3).expect("line 3 is in range");

        assert!(snippet.contains(">    3 | line 3"));
    }

    #[test]
    fn test_extract_snippet_invalid_line() {
        let content = "line 1\nline 2\n".to_string();

        assert!(extract_snippet(Some(&content), 10).is_none());
    }

    #[test]
    fn test_extract_snippet_line_zero() {
        let content = "line 1\nline 2\n".to_string();

        assert!(extract_snippet(Some(&content), 0).is_none());
    }

    #[test]
    fn test_extract_snippet_missing_content() {
        assert!(extract_snippet(None, 1).is_none());
    }

    #[test]
    fn test_extract_snippet_limits_window() {
        let content: String = (1..=30).map(|i| format!("l{i}\n")).collect();
        let snippet = extract_snippet(Some(&content), 15).expect("line 15 is in range");

        // CONTEXT_LINES before + the target + CONTEXT_LINES after.
        assert_eq!(snippet.lines().count(), 2 * CONTEXT_LINES + 1);
        assert!(snippet.starts_with("     5 | l5\n"));
        assert!(snippet.ends_with("    25 | l25\n"));
    }

    #[test]
    fn test_fallback_validation_high_confidence() {
        let finding = finding_with_confidence(Confidence::High);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("High"));
    }

    #[test]
    fn test_fallback_validation_low_confidence() {
        let finding = finding_with_confidence(Confidence::Low);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(!validated.is_valid);
        assert!(validated.reasoning.contains("Low"));
    }

    #[test]
    fn test_build_system_prompt() {
        let prompt = SecurityValidator::build_system_prompt();
        assert!(prompt.contains("security code reviewer"));
        assert!(prompt.contains("\"results\""));
        assert!(prompt.contains("\"index\""));
        assert!(prompt.contains("\"is_valid\""));
        assert!(prompt.contains("\"reasoning\""));
    }

    #[test]
    fn test_parse_validation_response() {
        let json = r#"{
            "results": [
                {
                    "index": 0,
                    "is_valid": true,
                    "reasoning": "This is a real vulnerability"
                },
                {
                    "index": 1,
                    "is_valid": false,
                    "reasoning": "This is test data"
                }
            ]
        }"#;

        let response: ValidationResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].index, 0);
        assert!(response.results[0].is_valid);
        assert_eq!(response.results[1].index, 1);
        assert!(!response.results[1].is_valid);
    }
}