1use anyhow::{Context, Result};
10use tracing::instrument;
11
12use super::types::{Finding, ValidatedFinding, ValidationResult};
13use crate::ai::client::AiClient;
14use crate::ai::provider::AiProvider;
15use crate::ai::types::{ChatCompletionRequest, ChatMessage, ResponseFormat};
16
/// Number of source lines shown on each side of a finding's line when a
/// context snippet is extracted for the validation prompt.
const CONTEXT_LINES: usize = 10;
19
/// Top-level JSON object the model is asked to return: `{"results": [...]}`.
#[derive(serde::Deserialize)]
struct ValidationResponse {
    /// Per-finding verdicts parsed from the response's `results` array.
    results: Vec<ValidationResult>,
}
25
/// Validates pattern-based security findings by asking an AI model whether
/// each one is a real issue, falling back to the finding's own pattern
/// confidence when no model verdict is available.
#[derive(Debug)]
pub struct SecurityValidator {
    /// Client used to send chat-completion requests and to read the model
    /// name and token limit.
    ai_client: AiClient,
}
35
36impl SecurityValidator {
37 pub fn new(ai_client: AiClient) -> Self {
43 Self { ai_client }
44 }
45
46 #[instrument(skip(self, findings, file_contents), fields(count = findings.len()))]
60 pub async fn validate_findings_batch(
61 &self,
62 findings: &[Finding],
63 file_contents: &std::collections::HashMap<String, String>,
64 ) -> Result<Vec<ValidatedFinding>> {
65 if findings.is_empty() {
66 return Ok(Vec::new());
67 }
68
69 let prompt = Self::build_batch_validation_prompt(findings, file_contents);
71
72 let request = ChatCompletionRequest {
74 model: self.ai_client.model().to_string(),
75 messages: vec![
76 ChatMessage {
77 role: "system".to_string(),
78 content: Some(Self::build_system_prompt()),
79 reasoning: None,
80 cache_control: None,
81 },
82 ChatMessage {
83 role: "user".to_string(),
84 content: Some(prompt),
85 reasoning: None,
86 cache_control: None,
87 },
88 ],
89 response_format: Some(ResponseFormat {
90 format_type: "json_object".to_string(),
91 json_schema: None,
92 }),
93 max_tokens: Some(self.ai_client.max_tokens()),
94 temperature: Some(0.3),
95 };
96
97 match self.send_and_parse(&request).await {
99 Ok(results) => {
100 let mut validated = Vec::new();
102 for (i, finding) in findings.iter().enumerate() {
103 if let Some(result) = results.iter().find(|r| r.index == i) {
104 validated.push(ValidatedFinding {
105 finding: finding.clone(),
106 is_valid: result.is_valid,
107 reasoning: result.reasoning.clone(),
108 model_version: Some(self.ai_client.model().to_string()),
109 });
110 } else {
111 validated.push(Self::fallback_validation(finding));
113 }
114 }
115 Ok(validated)
116 }
117 Err(e) => {
118 tracing::warn!(error = %e, "LLM validation failed, using pattern confidence");
120 Ok(findings.iter().map(Self::fallback_validation).collect())
121 }
122 }
123 }
124
125 fn build_system_prompt() -> String {
127 r#"You are a security code reviewer. Analyze the provided security findings and determine if they are real vulnerabilities or false positives.
128
129Your response MUST be valid JSON with this exact schema:
130{
131 "results": [
132 {
133 "index": 0,
134 "is_valid": true,
135 "reasoning": "Brief explanation of why this is/isn't a real issue"
136 }
137 ]
138}
139
140Guidelines:
141- index: The 0-based index of the finding in the batch
142- is_valid: true if this is a real security issue, false if it's a false positive
143- reasoning: 1-2 sentence explanation of your decision
144
145Consider:
1461. Context: Is the code actually vulnerable in its usage context?
1472. False positives: Test data, comments, documentation, or safe patterns?
1483. Severity: Does the finding match the claimed severity?
1494. Mitigation: Are there compensating controls in place?
150
151Be conservative: when in doubt, mark as valid to avoid missing real issues."#
152 .to_string()
153 }
154
155 fn build_batch_validation_prompt(
157 findings: &[Finding],
158 file_contents: &std::collections::HashMap<String, String>,
159 ) -> String {
160 use std::fmt::Write;
161
162 let mut prompt = String::new();
163 prompt.push_str("Analyze these security findings:\n\n");
164
165 for (i, finding) in findings.iter().enumerate() {
166 let _ = writeln!(prompt, "Finding {i}:");
167 let _ = writeln!(prompt, " Pattern: {}", finding.pattern_id);
168 let _ = writeln!(prompt, " Description: {}", finding.description);
169 let _ = writeln!(
170 prompt,
171 " Severity: {:?}, Confidence: {:?}",
172 finding.severity, finding.confidence
173 );
174 let _ = writeln!(
175 prompt,
176 " File: {}:{}",
177 finding.file_path, finding.line_number
178 );
179 let _ = writeln!(prompt, " Matched: {}", finding.matched_text);
180
181 if let Some(snippet) =
183 extract_snippet(file_contents.get(&finding.file_path), finding.line_number)
184 {
185 let _ = writeln!(prompt, " Context:\n{snippet}");
186 }
187
188 prompt.push('\n');
189 }
190
191 prompt
192 }
193
194 async fn send_and_parse(
196 &self,
197 request: &ChatCompletionRequest,
198 ) -> Result<Vec<ValidationResult>> {
199 let completion = self.ai_client.send_request_inner(request).await?;
201
202 let content = completion
204 .choices
205 .first()
206 .and_then(|c| {
207 c.message
208 .content
209 .clone()
210 .or_else(|| c.message.reasoning.clone())
211 })
212 .context("No response from AI model")?;
213
214 let response: ValidationResponse = serde_json::from_str(&content)
216 .context("Failed to parse validation response as JSON")?;
217
218 Ok(response.results)
219 }
220
221 fn fallback_validation(finding: &Finding) -> ValidatedFinding {
223 use super::types::Confidence;
224
225 let is_valid = matches!(finding.confidence, Confidence::High | Confidence::Medium);
226 let reasoning = format!(
227 "LLM validation unavailable, using pattern confidence: {:?}",
228 finding.confidence
229 );
230
231 ValidatedFinding {
232 finding: finding.clone(),
233 is_valid,
234 reasoning,
235 model_version: None,
236 }
237 }
238}
239
240fn extract_snippet(content: Option<&String>, line_number: usize) -> Option<String> {
251 use std::fmt::Write;
252
253 let content = content?;
254 let lines: Vec<&str> = content.lines().collect();
255
256 if line_number == 0 || line_number > lines.len() {
257 return None;
258 }
259
260 let target_idx = line_number - 1;
262 let start = target_idx.saturating_sub(CONTEXT_LINES);
263 let end = (target_idx + CONTEXT_LINES + 1).min(lines.len());
264
265 let mut snippet = String::new();
266 for (i, line) in lines[start..end].iter().enumerate() {
267 let line_num = start + i + 1;
268 let marker = if line_num == line_number { ">" } else { " " };
269 let _ = writeln!(snippet, "{marker} {line_num:4} | {line}");
270 }
271
272 Some(snippet)
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::types::{Confidence, Severity};

    /// Builds a minimal finding whose only varying field is its confidence.
    fn finding_with_confidence(confidence: Confidence) -> Finding {
        Finding {
            pattern_id: "test-pattern".to_string(),
            description: "Test finding".to_string(),
            severity: Severity::High,
            confidence,
            file_path: "test.rs".to_string(),
            line_number: 1,
            matched_text: "test".to_string(),
            cwe: None,
        }
    }

    // Snippet lines are rendered as `{marker} {number:4} | {text}`, so the
    // line number is always right-aligned to a width of four characters.

    #[test]
    fn test_extract_snippet_with_context() {
        let content = "line 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\n".to_string();
        let snippet = extract_snippet(Some(&content), 4).expect("line 4 exists");

        assert!(snippet.contains(">    4 | line 4"));
        assert!(snippet.contains("     1 | line 1"));
        assert!(snippet.contains("     7 | line 7"));
    }

    #[test]
    fn test_extract_snippet_at_start() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 1).expect("line 1 exists");

        assert!(snippet.starts_with(">    1 | line 1\n"));
        assert!(!snippet.contains(" 0 |"));
    }

    #[test]
    fn test_extract_snippet_at_end() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 3).expect("line 3 exists");

        assert!(snippet.ends_with(">    3 | line 3\n"));
        assert_eq!(snippet.lines().count(), 3);
    }

    #[test]
    fn test_extract_snippet_invalid_line() {
        let content = "line 1\nline 2\n".to_string();

        assert!(extract_snippet(Some(&content), 10).is_none());
        // Line numbers are 1-based, so 0 is never valid.
        assert!(extract_snippet(Some(&content), 0).is_none());
    }

    #[test]
    fn test_extract_snippet_missing_content() {
        assert!(extract_snippet(None, 1).is_none());
    }

    #[test]
    fn test_fallback_validation_high_confidence() {
        let finding = finding_with_confidence(Confidence::High);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("High"));
    }

    #[test]
    fn test_fallback_validation_low_confidence() {
        let finding = finding_with_confidence(Confidence::Low);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(!validated.is_valid);
        assert!(validated.reasoning.contains("Low"));
    }

    #[test]
    fn test_build_system_prompt() {
        let prompt = SecurityValidator::build_system_prompt();
        assert!(prompt.contains("security code reviewer"));
        assert!(prompt.contains("\"results\""));
        assert!(prompt.contains("\"index\""));
        assert!(prompt.contains("\"is_valid\""));
        assert!(prompt.contains("\"reasoning\""));
    }

    #[test]
    fn test_parse_validation_response() {
        let json = r#"{
 "results": [
 {
 "index": 0,
 "is_valid": true,
 "reasoning": "This is a real vulnerability"
 },
 {
 "index": 1,
 "is_valid": false,
 "reasoning": "This is test data"
 }
 ]
 }"#;

        let response: ValidationResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].index, 0);
        assert!(response.results[0].is_valid);
        assert_eq!(response.results[1].index, 1);
        assert!(!response.results[1].is_valid);
    }
}