Skip to main content

aptu_core/security/
validator.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! LLM-based validation for security findings.
4//!
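//! Provides batched validation of security findings using an LLM to reduce false positives.
//! Findings are intended to be sent in small batches (3-5 per LLM call) for efficiency;
//! when the LLM request fails or its response cannot be parsed, validation falls back
//! to each finding's pattern confidence.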
8
9use anyhow::{Context, Result};
10use tracing::instrument;
11
12use super::types::{Finding, ValidatedFinding, ValidationResult};
13use crate::ai::client::AiClient;
14use crate::ai::provider::AiProvider;
15use crate::ai::types::{ChatCompletionRequest, ChatMessage, ResponseFormat};
16
/// Number of context lines shown on EACH side of a finding's line when building
/// a snippet, so a snippet spans at most `2 * CONTEXT_LINES + 1` lines (fewer when
/// clipped by the start or end of the file).
const CONTEXT_LINES: usize = 10;
19
20/// Internal response structure for LLM validation.
21#[derive(serde::Deserialize)]
22struct ValidationResponse {
23    results: Vec<ValidationResult>,
24}
25
26/// Security finding validator using LLM.
27///
28/// Validates security findings in batches to reduce false positives.
29/// Falls back to pattern confidence if LLM validation fails.
30#[derive(Debug)]
31pub struct SecurityValidator {
32    /// AI client for LLM calls.
33    ai_client: AiClient,
34}
35
36impl SecurityValidator {
37    /// Creates a new security validator.
38    ///
39    /// # Arguments
40    ///
41    /// * `ai_client` - AI client configured for validation
42    pub fn new(ai_client: AiClient) -> Self {
43        Self { ai_client }
44    }
45
46    /// Validates a batch of security findings using LLM.
47    ///
48    /// Sends up to `BATCH_SIZE` findings to the LLM for validation.
49    /// Falls back to pattern confidence if LLM response is malformed.
50    ///
51    /// # Arguments
52    ///
53    /// * `findings` - Security findings to validate
54    /// * `file_contents` - Map of file paths to their contents for context extraction
55    ///
56    /// # Returns
57    ///
58    /// Vector of validated findings with LLM reasoning
59    #[instrument(skip(self, findings, file_contents), fields(count = findings.len()))]
60    pub async fn validate_findings_batch(
61        &self,
62        findings: &[Finding],
63        file_contents: &std::collections::HashMap<String, String>,
64    ) -> Result<Vec<ValidatedFinding>> {
65        if findings.is_empty() {
66            return Ok(Vec::new());
67        }
68
69        // Build validation prompt
70        let prompt = Self::build_batch_validation_prompt(findings, file_contents);
71
72        // Build request
73        let request = ChatCompletionRequest {
74            model: self.ai_client.model().to_string(),
75            messages: vec![
76                ChatMessage {
77                    role: "system".to_string(),
78                    content: Some(Self::build_system_prompt()),
79                    reasoning: None,
80                },
81                ChatMessage {
82                    role: "user".to_string(),
83                    content: Some(prompt),
84                    reasoning: None,
85                },
86            ],
87            response_format: Some(ResponseFormat {
88                format_type: "json_object".to_string(),
89                json_schema: None,
90            }),
91            max_tokens: Some(self.ai_client.max_tokens()),
92            temperature: Some(0.3),
93        };
94
95        // Send request and parse response
96        match self.send_and_parse(&request).await {
97            Ok(results) => {
98                // Map results to validated findings
99                let mut validated = Vec::new();
100                for (i, finding) in findings.iter().enumerate() {
101                    if let Some(result) = results.iter().find(|r| r.index == i) {
102                        validated.push(ValidatedFinding {
103                            finding: finding.clone(),
104                            is_valid: result.is_valid,
105                            reasoning: result.reasoning.clone(),
106                            model_version: Some(self.ai_client.model().to_string()),
107                        });
108                    } else {
109                        // Fallback: use pattern confidence
110                        validated.push(Self::fallback_validation(finding));
111                    }
112                }
113                Ok(validated)
114            }
115            Err(e) => {
116                // Fallback: use pattern confidence for all findings
117                tracing::warn!(error = %e, "LLM validation failed, using pattern confidence");
118                Ok(findings.iter().map(Self::fallback_validation).collect())
119            }
120        }
121    }
122
123    /// Builds the system prompt for validation.
124    fn build_system_prompt() -> String {
125        r#"You are a security code reviewer. Analyze the provided security findings and determine if they are real vulnerabilities or false positives.
126
127Your response MUST be valid JSON with this exact schema:
128{
129  "results": [
130    {
131      "index": 0,
132      "is_valid": true,
133      "reasoning": "Brief explanation of why this is/isn't a real issue"
134    }
135  ]
136}
137
138Guidelines:
139- index: The 0-based index of the finding in the batch
140- is_valid: true if this is a real security issue, false if it's a false positive
141- reasoning: 1-2 sentence explanation of your decision
142
143Consider:
1441. Context: Is the code actually vulnerable in its usage context?
1452. False positives: Test data, comments, documentation, or safe patterns?
1463. Severity: Does the finding match the claimed severity?
1474. Mitigation: Are there compensating controls in place?
148
149Be conservative: when in doubt, mark as valid to avoid missing real issues."#
150            .to_string()
151    }
152
153    /// Builds the validation prompt for a batch of findings.
154    fn build_batch_validation_prompt(
155        findings: &[Finding],
156        file_contents: &std::collections::HashMap<String, String>,
157    ) -> String {
158        use std::fmt::Write;
159
160        let mut prompt = String::new();
161        prompt.push_str("Analyze these security findings:\n\n");
162
163        for (i, finding) in findings.iter().enumerate() {
164            let _ = writeln!(prompt, "Finding {i}:");
165            let _ = writeln!(prompt, "  Pattern: {}", finding.pattern_id);
166            let _ = writeln!(prompt, "  Description: {}", finding.description);
167            let _ = writeln!(
168                prompt,
169                "  Severity: {:?}, Confidence: {:?}",
170                finding.severity, finding.confidence
171            );
172            let _ = writeln!(
173                prompt,
174                "  File: {}:{}",
175                finding.file_path, finding.line_number
176            );
177            let _ = writeln!(prompt, "  Matched: {}", finding.matched_text);
178
179            // Extract context snippet
180            if let Some(snippet) =
181                extract_snippet(file_contents.get(&finding.file_path), finding.line_number)
182            {
183                let _ = writeln!(prompt, "  Context:\n{snippet}");
184            }
185
186            prompt.push('\n');
187        }
188
189        prompt
190    }
191
192    /// Sends a validation request and parses the response.
193    async fn send_and_parse(
194        &self,
195        request: &ChatCompletionRequest,
196    ) -> Result<Vec<ValidationResult>> {
197        // Send request using AiProvider trait
198        let completion = self.ai_client.send_request_inner(request).await?;
199
200        // Extract message content
201        let content = completion
202            .choices
203            .first()
204            .and_then(|c| {
205                c.message
206                    .content
207                    .clone()
208                    .or_else(|| c.message.reasoning.clone())
209            })
210            .context("No response from AI model")?;
211
212        // Parse JSON response
213        let response: ValidationResponse = serde_json::from_str(&content)
214            .context("Failed to parse validation response as JSON")?;
215
216        Ok(response.results)
217    }
218
219    /// Creates a fallback validated finding using pattern confidence.
220    fn fallback_validation(finding: &Finding) -> ValidatedFinding {
221        use super::types::Confidence;
222
223        let is_valid = matches!(finding.confidence, Confidence::High | Confidence::Medium);
224        let reasoning = format!(
225            "LLM validation unavailable, using pattern confidence: {:?}",
226            finding.confidence
227        );
228
229        ValidatedFinding {
230            finding: finding.clone(),
231            is_valid,
232            reasoning,
233            model_version: None,
234        }
235    }
236}
237
238/// Extracts a code snippet with context around a line number.
239///
240/// # Arguments
241///
242/// * `content` - File content
243/// * `line_number` - Target line number (1-indexed)
244///
245/// # Returns
246///
247/// Code snippet with up to `CONTEXT_LINES` before and after the target line
248fn extract_snippet(content: Option<&String>, line_number: usize) -> Option<String> {
249    use std::fmt::Write;
250
251    let content = content?;
252    let lines: Vec<&str> = content.lines().collect();
253
254    if line_number == 0 || line_number > lines.len() {
255        return None;
256    }
257
258    // Calculate range (1-indexed to 0-indexed)
259    let target_idx = line_number - 1;
260    let start = target_idx.saturating_sub(CONTEXT_LINES);
261    let end = (target_idx + CONTEXT_LINES + 1).min(lines.len());
262
263    let mut snippet = String::new();
264    for (i, line) in lines[start..end].iter().enumerate() {
265        let line_num = start + i + 1;
266        let marker = if line_num == line_number { ">" } else { " " };
267        let _ = writeln!(snippet, "{marker} {line_num:4} | {line}");
268    }
269
270    Some(snippet)
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::types::{Confidence, Severity};

    /// Builds a minimal finding at `test.rs:1` whose only varying attribute is `confidence`.
    fn finding_with_confidence(confidence: Confidence) -> Finding {
        Finding {
            pattern_id: "test-pattern".to_string(),
            description: "Test finding".to_string(),
            severity: Severity::High,
            confidence,
            file_path: "test.rs".to_string(),
            line_number: 1,
            matched_text: "test".to_string(),
            cwe: None,
        }
    }

    #[test]
    fn test_extract_snippet_with_context() {
        let content = "line 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\n".to_string();
        let snippet = extract_snippet(Some(&content), 4);

        assert!(snippet.is_some());
        let snippet = snippet.unwrap();
        assert!(snippet.contains(">    4 | line 4"));
        assert!(snippet.contains("     1 | line 1"));
        assert!(snippet.contains("     7 | line 7"));
    }

    #[test]
    fn test_extract_snippet_at_start() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 1);

        assert!(snippet.is_some());
        let snippet = snippet.unwrap();
        assert!(snippet.contains(">    1 | line 1"));
        assert!(!snippet.contains("     0 |"));
    }

    #[test]
    fn test_extract_snippet_at_end() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 3);

        assert!(snippet.is_some());
        let snippet = snippet.unwrap();
        assert!(snippet.contains(">    3 | line 3"));
    }

    #[test]
    fn test_extract_snippet_invalid_line() {
        let content = "line 1\nline 2\n".to_string();
        let snippet = extract_snippet(Some(&content), 10);

        assert!(snippet.is_none());
    }

    #[test]
    fn test_extract_snippet_zero_line() {
        // Line numbers are 1-based, so 0 is never a valid target.
        let content = "line 1\nline 2\n".to_string();
        assert!(extract_snippet(Some(&content), 0).is_none());
    }

    #[test]
    fn test_extract_snippet_missing_content() {
        assert!(extract_snippet(None, 1).is_none());
    }

    #[test]
    fn test_extract_snippet_clips_to_context_window() {
        // Centre the target far from both file edges so clipping comes only from
        // CONTEXT_LINES.
        let total = 4 * CONTEXT_LINES;
        let target = 2 * CONTEXT_LINES;
        let content: String = (1..=total).map(|i| format!("line {i}\n")).collect();
        let snippet = extract_snippet(Some(&content), target).unwrap();

        let first = target - CONTEXT_LINES;
        let last = target + CONTEXT_LINES;
        assert_eq!(snippet.lines().count(), 2 * CONTEXT_LINES + 1);
        assert!(snippet.contains(&format!("| line {first}\n")));
        assert!(!snippet.contains(&format!("| line {}\n", first - 1)));
        assert!(snippet.contains(&format!("| line {last}\n")));
        assert!(!snippet.contains(&format!("| line {}\n", last + 1)));
    }

    #[test]
    fn test_fallback_validation_high_confidence() {
        let validated = SecurityValidator::fallback_validation(&finding_with_confidence(
            Confidence::High,
        ));
        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("High"));
        assert!(validated.model_version.is_none());
    }

    #[test]
    fn test_fallback_validation_medium_confidence() {
        let validated = SecurityValidator::fallback_validation(&finding_with_confidence(
            Confidence::Medium,
        ));
        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("Medium"));
    }

    #[test]
    fn test_fallback_validation_low_confidence() {
        let validated = SecurityValidator::fallback_validation(&finding_with_confidence(
            Confidence::Low,
        ));
        assert!(!validated.is_valid);
        assert!(validated.reasoning.contains("Low"));
    }

    #[test]
    fn test_build_system_prompt() {
        let prompt = SecurityValidator::build_system_prompt();
        assert!(prompt.contains("security code reviewer"));
        assert!(prompt.contains("\"results\""));
        assert!(prompt.contains("\"index\""));
        assert!(prompt.contains("\"is_valid\""));
        assert!(prompt.contains("\"reasoning\""));
    }

    #[test]
    fn test_build_batch_validation_prompt_includes_finding_and_context() {
        let finding = finding_with_confidence(Confidence::High);
        let mut files = std::collections::HashMap::new();
        files.insert("test.rs".to_string(), "let secret = \"x\";\n".to_string());

        let prompt = SecurityValidator::build_batch_validation_prompt(&[finding], &files);
        assert!(prompt.contains("Finding 0:"));
        assert!(prompt.contains("  Pattern: test-pattern"));
        assert!(prompt.contains("  File: test.rs:1"));
        assert!(prompt.contains("  Context:\n>    1 | let secret = \"x\";"));
    }

    #[test]
    fn test_parse_validation_response() {
        let json = r#"{
            "results": [
                {
                    "index": 0,
                    "is_valid": true,
                    "reasoning": "This is a real vulnerability"
                },
                {
                    "index": 1,
                    "is_valid": false,
                    "reasoning": "This is test data"
                }
            ]
        }"#;

        let response: ValidationResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].index, 0);
        assert!(response.results[0].is_valid);
        assert_eq!(response.results[1].index, 1);
        assert!(!response.results[1].is_valid);
    }
}