Skip to main content

aptu_core/security/
validator.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! LLM-based validation for security findings.
4//!
5//! Provides batched validation of security findings using AI to reduce false positives.
6//! Batches 3-5 findings per LLM call for efficiency, with fallback to pattern confidence
7//! on parsing errors.
8
9use anyhow::{Context, Result};
10use tracing::instrument;
11
12use super::types::{Finding, ValidatedFinding, ValidationResult};
13use crate::ai::client::AiClient;
14use crate::ai::provider::AiProvider;
15use crate::ai::types::{ChatCompletionRequest, ChatMessage, ResponseFormat};
16
17/// Maximum lines of context to extract around a finding.
18const CONTEXT_LINES: usize = 10;
19
/// Internal response structure for LLM validation.
///
/// Deserialization target for the JSON object the model is asked to return;
/// its shape mirrors the schema spelled out in the system prompt.
#[derive(serde::Deserialize)]
struct ValidationResponse {
    /// Per-finding verdicts, read from the top-level `"results"` array.
    results: Vec<ValidationResult>,
}
25
/// Security finding validator using LLM.
///
/// Validates security findings in batches to reduce false positives.
/// Falls back to pattern confidence if LLM validation fails.
#[derive(Debug)]
pub struct SecurityValidator {
    /// AI client for LLM calls. It also supplies the model name and the
    /// `max_tokens` limit placed on every validation request.
    ai_client: AiClient,
}
35
36impl SecurityValidator {
37    /// Creates a new security validator.
38    ///
39    /// # Arguments
40    ///
41    /// * `ai_client` - AI client configured for validation
42    pub fn new(ai_client: AiClient) -> Self {
43        Self { ai_client }
44    }
45
46    /// Validates a batch of security findings using LLM.
47    ///
48    /// Sends up to `BATCH_SIZE` findings to the LLM for validation.
49    /// Falls back to pattern confidence if LLM response is malformed.
50    ///
51    /// # Arguments
52    ///
53    /// * `findings` - Security findings to validate
54    /// * `file_contents` - Map of file paths to their contents for context extraction
55    ///
56    /// # Returns
57    ///
58    /// Vector of validated findings with LLM reasoning
59    #[instrument(skip(self, findings, file_contents), fields(count = findings.len()))]
60    pub async fn validate_findings_batch(
61        &self,
62        findings: &[Finding],
63        file_contents: &std::collections::HashMap<String, String>,
64    ) -> Result<Vec<ValidatedFinding>> {
65        if findings.is_empty() {
66            return Ok(Vec::new());
67        }
68
69        // Build validation prompt
70        let prompt = Self::build_batch_validation_prompt(findings, file_contents);
71
72        // Build request
73        let request = ChatCompletionRequest {
74            model: self.ai_client.model().to_string(),
75            messages: vec![
76                ChatMessage {
77                    role: "system".to_string(),
78                    content: Self::build_system_prompt(),
79                },
80                ChatMessage {
81                    role: "user".to_string(),
82                    content: prompt,
83                },
84            ],
85            response_format: Some(ResponseFormat {
86                format_type: "json_object".to_string(),
87                json_schema: None,
88            }),
89            max_tokens: Some(self.ai_client.max_tokens()),
90            temperature: Some(0.3),
91        };
92
93        // Send request and parse response
94        match self.send_and_parse(&request).await {
95            Ok(results) => {
96                // Map results to validated findings
97                let mut validated = Vec::new();
98                for (i, finding) in findings.iter().enumerate() {
99                    if let Some(result) = results.iter().find(|r| r.index == i) {
100                        validated.push(ValidatedFinding {
101                            finding: finding.clone(),
102                            is_valid: result.is_valid,
103                            reasoning: result.reasoning.clone(),
104                            model_version: Some(self.ai_client.model().to_string()),
105                        });
106                    } else {
107                        // Fallback: use pattern confidence
108                        validated.push(Self::fallback_validation(finding));
109                    }
110                }
111                Ok(validated)
112            }
113            Err(e) => {
114                // Fallback: use pattern confidence for all findings
115                tracing::warn!(error = %e, "LLM validation failed, using pattern confidence");
116                Ok(findings.iter().map(Self::fallback_validation).collect())
117            }
118        }
119    }
120
121    /// Builds the system prompt for validation.
122    fn build_system_prompt() -> String {
123        r#"You are a security code reviewer. Analyze the provided security findings and determine if they are real vulnerabilities or false positives.
124
125Your response MUST be valid JSON with this exact schema:
126{
127  "results": [
128    {
129      "index": 0,
130      "is_valid": true,
131      "reasoning": "Brief explanation of why this is/isn't a real issue"
132    }
133  ]
134}
135
136Guidelines:
137- index: The 0-based index of the finding in the batch
138- is_valid: true if this is a real security issue, false if it's a false positive
139- reasoning: 1-2 sentence explanation of your decision
140
141Consider:
1421. Context: Is the code actually vulnerable in its usage context?
1432. False positives: Test data, comments, documentation, or safe patterns?
1443. Severity: Does the finding match the claimed severity?
1454. Mitigation: Are there compensating controls in place?
146
147Be conservative: when in doubt, mark as valid to avoid missing real issues."#
148            .to_string()
149    }
150
151    /// Builds the validation prompt for a batch of findings.
152    fn build_batch_validation_prompt(
153        findings: &[Finding],
154        file_contents: &std::collections::HashMap<String, String>,
155    ) -> String {
156        use std::fmt::Write;
157
158        let mut prompt = String::new();
159        prompt.push_str("Analyze these security findings:\n\n");
160
161        for (i, finding) in findings.iter().enumerate() {
162            let _ = writeln!(prompt, "Finding {i}:");
163            let _ = writeln!(prompt, "  Pattern: {}", finding.pattern_id);
164            let _ = writeln!(prompt, "  Description: {}", finding.description);
165            let _ = writeln!(
166                prompt,
167                "  Severity: {:?}, Confidence: {:?}",
168                finding.severity, finding.confidence
169            );
170            let _ = writeln!(
171                prompt,
172                "  File: {}:{}",
173                finding.file_path, finding.line_number
174            );
175            let _ = writeln!(prompt, "  Matched: {}", finding.matched_text);
176
177            // Extract context snippet
178            if let Some(snippet) =
179                extract_snippet(file_contents.get(&finding.file_path), finding.line_number)
180            {
181                let _ = writeln!(prompt, "  Context:\n{snippet}");
182            }
183
184            prompt.push('\n');
185        }
186
187        prompt
188    }
189
190    /// Sends a validation request and parses the response.
191    async fn send_and_parse(
192        &self,
193        request: &ChatCompletionRequest,
194    ) -> Result<Vec<ValidationResult>> {
195        // Send request using AiProvider trait
196        let completion = self.ai_client.send_request_inner(request).await?;
197
198        // Extract message content
199        let content = completion
200            .choices
201            .first()
202            .map(|c| c.message.content.clone())
203            .context("No response from AI model")?;
204
205        // Parse JSON response
206        let response: ValidationResponse = serde_json::from_str(&content)
207            .context("Failed to parse validation response as JSON")?;
208
209        Ok(response.results)
210    }
211
212    /// Creates a fallback validated finding using pattern confidence.
213    fn fallback_validation(finding: &Finding) -> ValidatedFinding {
214        use super::types::Confidence;
215
216        let is_valid = matches!(finding.confidence, Confidence::High | Confidence::Medium);
217        let reasoning = format!(
218            "LLM validation unavailable, using pattern confidence: {:?}",
219            finding.confidence
220        );
221
222        ValidatedFinding {
223            finding: finding.clone(),
224            is_valid,
225            reasoning,
226            model_version: None,
227        }
228    }
229}
230
231/// Extracts a code snippet with context around a line number.
232///
233/// # Arguments
234///
235/// * `content` - File content
236/// * `line_number` - Target line number (1-indexed)
237///
238/// # Returns
239///
240/// Code snippet with up to `CONTEXT_LINES` before and after the target line
241fn extract_snippet(content: Option<&String>, line_number: usize) -> Option<String> {
242    use std::fmt::Write;
243
244    let content = content?;
245    let lines: Vec<&str> = content.lines().collect();
246
247    if line_number == 0 || line_number > lines.len() {
248        return None;
249    }
250
251    // Calculate range (1-indexed to 0-indexed)
252    let target_idx = line_number - 1;
253    let start = target_idx.saturating_sub(CONTEXT_LINES);
254    let end = (target_idx + CONTEXT_LINES + 1).min(lines.len());
255
256    let mut snippet = String::new();
257    for (i, line) in lines[start..end].iter().enumerate() {
258        let line_num = start + i + 1;
259        let marker = if line_num == line_number { ">" } else { " " };
260        let _ = writeln!(snippet, "{marker} {line_num:4} | {line}");
261    }
262
263    Some(snippet)
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::types::{Confidence, Severity};

    /// Builds a minimal high-severity finding whose only varying field is its
    /// pattern confidence.
    fn finding_with_confidence(confidence: Confidence) -> Finding {
        Finding {
            pattern_id: "test-pattern".to_string(),
            description: "Test finding".to_string(),
            severity: Severity::High,
            confidence,
            file_path: "test.rs".to_string(),
            line_number: 1,
            matched_text: "test".to_string(),
            cwe: None,
        }
    }

    #[test]
    fn test_extract_snippet_with_context() {
        let content = "line 1\nline 2\nline 3\nline 4\nline 5\nline 6\nline 7\n".to_string();
        let snippet = extract_snippet(Some(&content), 4).expect("line 4 is inside the file");

        assert!(snippet.contains(">    4 | line 4"));
        assert!(snippet.contains("     1 | line 1"));
        assert!(snippet.contains("     7 | line 7"));
    }

    #[test]
    fn test_extract_snippet_at_start() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 1).expect("line 1 is inside the file");

        assert!(snippet.contains(">    1 | line 1"));
        assert!(!snippet.contains("     0 |"));
    }

    #[test]
    fn test_extract_snippet_at_end() {
        let content = "line 1\nline 2\nline 3\n".to_string();
        let snippet = extract_snippet(Some(&content), 3).expect("line 3 is inside the file");

        assert!(snippet.contains(">    3 | line 3"));
    }

    #[test]
    fn test_extract_snippet_invalid_line() {
        let content = "line 1\nline 2\n".to_string();

        assert!(extract_snippet(Some(&content), 10).is_none());
    }

    #[test]
    fn test_extract_snippet_line_zero() {
        // Line numbers are 1-indexed, so 0 never identifies a line.
        let content = "line 1\nline 2\n".to_string();

        assert!(extract_snippet(Some(&content), 0).is_none());
    }

    #[test]
    fn test_extract_snippet_missing_content() {
        assert!(extract_snippet(None, 1).is_none());
    }

    #[test]
    fn test_extract_snippet_limits_context() {
        // A file long enough that the window is clipped on neither side.
        let total = CONTEXT_LINES * 3;
        let target = CONTEXT_LINES + 5;
        let content: String = (1..=total).map(|n| format!("row {n}\n")).collect();

        let snippet = extract_snippet(Some(&content), target).expect("target is inside the file");

        assert_eq!(snippet.lines().count(), 2 * CONTEXT_LINES + 1);
        let first = snippet.lines().next().expect("snippet is not empty");
        let last = snippet.lines().last().expect("snippet is not empty");
        assert!(first.ends_with(&format!("| row {}", target - CONTEXT_LINES)));
        assert!(last.ends_with(&format!("| row {}", target + CONTEXT_LINES)));
    }

    #[test]
    fn test_fallback_validation_high_confidence() {
        let finding = finding_with_confidence(Confidence::High);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("High"));
        assert!(validated.model_version.is_none());
    }

    #[test]
    fn test_fallback_validation_medium_confidence() {
        let finding = finding_with_confidence(Confidence::Medium);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(validated.is_valid);
        assert!(validated.reasoning.contains("Medium"));
    }

    #[test]
    fn test_fallback_validation_low_confidence() {
        let finding = finding_with_confidence(Confidence::Low);

        let validated = SecurityValidator::fallback_validation(&finding);
        assert!(!validated.is_valid);
        assert!(validated.reasoning.contains("Low"));
    }

    #[test]
    fn test_build_system_prompt() {
        let prompt = SecurityValidator::build_system_prompt();
        assert!(prompt.contains("security code reviewer"));
        assert!(prompt.contains("\"results\""));
        assert!(prompt.contains("\"index\""));
        assert!(prompt.contains("\"is_valid\""));
        assert!(prompt.contains("\"reasoning\""));
    }

    #[test]
    fn test_parse_validation_response() {
        let json = r#"{
            "results": [
                {
                    "index": 0,
                    "is_valid": true,
                    "reasoning": "This is a real vulnerability"
                },
                {
                    "index": 1,
                    "is_valid": false,
                    "reasoning": "This is test data"
                }
            ]
        }"#;

        let response: ValidationResponse =
            serde_json::from_str(json).expect("fixture matches the response schema");
        assert_eq!(response.results.len(), 2);
        assert_eq!(response.results[0].index, 0);
        assert!(response.results[0].is_valid);
        assert_eq!(response.results[1].index, 1);
        assert!(!response.results[1].is_valid);
    }
}