Skip to main content

aptu_core/security/
validator.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! LLM-based validation for security findings.
4//!
5//! Provides batched validation of security findings using AI to reduce false positives.
6//! Batches 3-5 findings per LLM call for efficiency, with fallback to pattern confidence
7//! on parsing errors.
8
9use anyhow::{Context, Result};
10use tracing::instrument;
11
12use super::types::{Finding, ValidatedFinding, ValidationResult};
13use crate::ai::client::AiClient;
14use crate::ai::provider::AiProvider;
15use crate::ai::types::{ChatCompletionRequest, ChatMessage, ResponseFormat};
16
/// Number of context lines to include on *each* side of a finding's line when
/// building a snippet, so a snippet spans at most `2 * CONTEXT_LINES + 1` lines
/// (fewer near the start or end of the file).
const CONTEXT_LINES: usize = 10;
19
20/// Internal response structure for LLM validation.
21#[derive(serde::Deserialize)]
22struct ValidationResponse {
23    results: Vec<ValidationResult>,
24}
25
26/// Security finding validator using LLM.
27///
28/// Validates security findings in batches to reduce false positives.
29/// Falls back to pattern confidence if LLM validation fails.
30#[derive(Debug)]
31pub struct SecurityValidator {
32    /// AI client for LLM calls.
33    ai_client: AiClient,
34}
35
36impl SecurityValidator {
37    /// Creates a new security validator.
38    ///
39    /// # Arguments
40    ///
41    /// * `ai_client` - AI client configured for validation
42    pub fn new(ai_client: AiClient) -> Self {
43        Self { ai_client }
44    }
45
46    /// Validates a batch of security findings using LLM.
47    ///
48    /// Sends up to `BATCH_SIZE` findings to the LLM for validation.
49    /// Falls back to pattern confidence if LLM response is malformed.
50    ///
51    /// # Arguments
52    ///
53    /// * `findings` - Security findings to validate
54    /// * `file_contents` - Map of file paths to their contents for context extraction
55    ///
56    /// # Returns
57    ///
58    /// Vector of validated findings with LLM reasoning
59    #[instrument(skip(self, findings, file_contents), fields(count = findings.len()))]
60    pub async fn validate_findings_batch(
61        &self,
62        findings: &[Finding],
63        file_contents: &std::collections::HashMap<String, String>,
64    ) -> Result<Vec<ValidatedFinding>> {
65        if findings.is_empty() {
66            return Ok(Vec::new());
67        }
68
69        // Build validation prompt
70        let prompt = Self::build_batch_validation_prompt(findings, file_contents);
71
72        // Build request
73        let request = ChatCompletionRequest {
74            model: self.ai_client.model().to_string(),
75            messages: vec![
76                ChatMessage {
77                    role: "system".to_string(),
78                    content: Some(Self::build_system_prompt()),
79                    reasoning: None,
80                    cache_control: None,
81                },
82                ChatMessage {
83                    role: "user".to_string(),
84                    content: Some(prompt),
85                    reasoning: None,
86                    cache_control: None,
87                },
88            ],
89            response_format: Some(ResponseFormat {
90                format_type: "json_object".to_string(),
91                json_schema: None,
92            }),
93            max_tokens: Some(self.ai_client.max_tokens()),
94            temperature: Some(0.3),
95        };
96
97        // Send request and parse response
98        match self.send_and_parse(&request).await {
99            Ok(results) => {
100                // Map results to validated findings
101                let mut validated = Vec::new();
102                for (i, finding) in findings.iter().enumerate() {
103                    if let Some(result) = results.iter().find(|r| r.index == i) {
104                        validated.push(ValidatedFinding {
105                            finding: finding.clone(),
106                            is_valid: result.is_valid,
107                            reasoning: result.reasoning.clone(),
108                            model_version: Some(self.ai_client.model().to_string()),
109                        });
110                    } else {
111                        // Fallback: use pattern confidence
112                        validated.push(Self::fallback_validation(finding));
113                    }
114                }
115                Ok(validated)
116            }
117            Err(e) => {
118                // Fallback: use pattern confidence for all findings
119                tracing::warn!(error = %e, "LLM validation failed, using pattern confidence");
120                Ok(findings.iter().map(Self::fallback_validation).collect())
121            }
122        }
123    }
124
125    /// Builds the system prompt for validation.
126    fn build_system_prompt() -> String {
127        r#"You are a security code reviewer. Analyze the provided security findings and determine if they are real vulnerabilities or false positives.
128
129Your response MUST be valid JSON with this exact schema:
130{
131  "results": [
132    {
133      "index": 0,
134      "is_valid": true,
135      "reasoning": "Brief explanation of why this is/isn't a real issue"
136    }
137  ]
138}
139
140Guidelines:
141- index: The 0-based index of the finding in the batch
142- is_valid: true if this is a real security issue, false if it's a false positive
143- reasoning: 1-2 sentence explanation of your decision
144
145Consider:
1461. Context: Is the code actually vulnerable in its usage context?
1472. False positives: Test data, comments, documentation, or safe patterns?
1483. Severity: Does the finding match the claimed severity?
1494. Mitigation: Are there compensating controls in place?
150
151Be conservative: when in doubt, mark as valid to avoid missing real issues."#
152            .to_string()
153    }
154
155    /// Builds the validation prompt for a batch of findings.
156    fn build_batch_validation_prompt(
157        findings: &[Finding],
158        file_contents: &std::collections::HashMap<String, String>,
159    ) -> String {
160        use std::fmt::Write;
161
162        let mut prompt = String::new();
163        prompt.push_str("Analyze these security findings:\n\n");
164
165        for (i, finding) in findings.iter().enumerate() {
166            let _ = writeln!(prompt, "Finding {i}:");
167            let _ = writeln!(prompt, "  Pattern: {}", finding.pattern_id);
168            let _ = writeln!(prompt, "  Description: {}", finding.description);
169            let _ = writeln!(
170                prompt,
171                "  Severity: {:?}, Confidence: {:?}",
172                finding.severity, finding.confidence
173            );
174            let _ = writeln!(
175                prompt,
176                "  File: {}:{}",
177                finding.file_path, finding.line_number
178            );
179            let _ = writeln!(prompt, "  Matched: {}", finding.matched_text);
180
181            // Extract context snippet
182            if let Some(snippet) =
183                extract_snippet(file_contents.get(&finding.file_path), finding.line_number)
184            {
185                let _ = writeln!(prompt, "  Context:\n{snippet}");
186            }
187
188            prompt.push('\n');
189        }
190
191        prompt
192    }
193
194    /// Sends a validation request and parses the response.
195    async fn send_and_parse(
196        &self,
197        request: &ChatCompletionRequest,
198    ) -> Result<Vec<ValidationResult>> {
199        // Send request using AiProvider trait
200        let completion = self.ai_client.send_request_inner(request).await?;
201
202        // Extract message content
203        let content = completion
204            .choices
205            .first()
206            .and_then(|c| {
207                c.message
208                    .content
209                    .clone()
210                    .or_else(|| c.message.reasoning.clone())
211            })
212            .context("No response from AI model")?;
213
214        // Parse JSON response
215        let response: ValidationResponse = serde_json::from_str(&content)
216            .context("Failed to parse validation response as JSON")?;
217
218        Ok(response.results)
219    }
220
221    /// Creates a fallback validated finding using pattern confidence.
222    fn fallback_validation(finding: &Finding) -> ValidatedFinding {
223        use super::types::Confidence;
224
225        let is_valid = matches!(finding.confidence, Confidence::High | Confidence::Medium);
226        let reasoning = format!(
227            "LLM validation unavailable, using pattern confidence: {:?}",
228            finding.confidence
229        );
230
231        ValidatedFinding {
232            finding: finding.clone(),
233            is_valid,
234            reasoning,
235            model_version: None,
236        }
237    }
238}
239
240/// Extracts a code snippet with context around a line number.
241///
242/// # Arguments
243///
244/// * `content` - File content
245/// * `line_number` - Target line number (1-indexed)
246///
247/// # Returns
248///
249/// Code snippet with up to `CONTEXT_LINES` before and after the target line
250fn extract_snippet(content: Option<&String>, line_number: usize) -> Option<String> {
251    use std::fmt::Write;
252
253    let content = content?;
254    let lines: Vec<&str> = content.lines().collect();
255
256    if line_number == 0 || line_number > lines.len() {
257        return None;
258    }
259
260    // Calculate range (1-indexed to 0-indexed)
261    let target_idx = line_number - 1;
262    let start = target_idx.saturating_sub(CONTEXT_LINES);
263    let end = (target_idx + CONTEXT_LINES + 1).min(lines.len());
264
265    let mut snippet = String::new();
266    for (i, line) in lines[start..end].iter().enumerate() {
267        let line_num = start + i + 1;
268        let marker = if line_num == line_number { ">" } else { " " };
269        let _ = writeln!(snippet, "{marker} {line_num:4} | {line}");
270    }
271
272    Some(snippet)
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::types::{Confidence, Severity};

    /// Builds the minimal high-severity finding used by the fallback tests,
    /// varying only its pattern confidence.
    fn finding_with_confidence(confidence: Confidence) -> Finding {
        Finding {
            pattern_id: "test-pattern".to_string(),
            description: "Test finding".to_string(),
            severity: Severity::High,
            confidence,
            file_path: "test.rs".to_string(),
            line_number: 1,
            matched_text: "test".to_string(),
            cwe: None,
        }
    }

    #[test]
    fn test_extract_snippet_with_context() {
        let content = String::from("line 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\n");
        let snippet = extract_snippet(Some(&content), 4).expect("line 4 is within the file");

        for expected in [">    4 | line 4", "     1 | line 1", "     7 | line 7"] {
            assert!(snippet.contains(expected), "missing {expected:?} in:\n{snippet}");
        }
    }

    #[test]
    fn test_extract_snippet_at_start() {
        let content = String::from("line 1\nline 2\nline 3\n");
        let snippet = extract_snippet(Some(&content), 1).expect("line 1 is within the file");

        assert!(snippet.contains(">    1 | line 1"));
        // Line numbering is 1-based, so there must never be a line 0.
        assert!(!snippet.contains("     0 |"));
    }

    #[test]
    fn test_extract_snippet_at_end() {
        let content = String::from("line 1\nline 2\nline 3\n");
        let snippet = extract_snippet(Some(&content), 3).expect("line 3 is within the file");

        assert!(snippet.contains(">    3 | line 3"));
    }

    #[test]
    fn test_extract_snippet_invalid_line() {
        let content = String::from("line 1\nline 2\n");

        assert!(extract_snippet(Some(&content), 10).is_none());
    }

    #[test]
    fn test_fallback_validation_high_confidence() {
        let validated =
            SecurityValidator::fallback_validation(&finding_with_confidence(Confidence::High));

        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("High"));
    }

    #[test]
    fn test_fallback_validation_low_confidence() {
        let validated =
            SecurityValidator::fallback_validation(&finding_with_confidence(Confidence::Low));

        assert!(!validated.is_valid);
        assert!(validated.reasoning.contains("Low"));
    }

    #[test]
    fn test_build_system_prompt() {
        let prompt = SecurityValidator::build_system_prompt();

        for needle in [
            "security code reviewer",
            "\"results\"",
            "\"index\"",
            "\"is_valid\"",
            "\"reasoning\"",
        ] {
            assert!(prompt.contains(needle), "system prompt is missing {needle:?}");
        }
    }

    #[test]
    fn test_parse_validation_response() {
        let json = r#"{
            "results": [
                {
                    "index": 0,
                    "is_valid": true,
                    "reasoning": "This is a real vulnerability"
                },
                {
                    "index": 1,
                    "is_valid": false,
                    "reasoning": "This is test data"
                }
            ]
        }"#;

        let response: ValidationResponse =
            serde_json::from_str(json).expect("fixture is valid JSON");
        let [first, second] = response.results.as_slice() else {
            panic!("expected exactly two results, got {}", response.results.len());
        };

        assert_eq!(first.index, 0);
        assert!(first.is_valid);
        assert_eq!(second.index, 1);
        assert!(!second.is_valid);
    }
}