Skip to main content

oxi/
output_guard.rs

1//! Output guard for checking assistant output for sensitive data
2//!
3//! Provides utilities to scan assistant output for potentially sensitive
4//! information like API keys, passwords, and other secrets.
5
6use regex::Regex;
7use std::sync::LazyLock;
8
9/// Pattern for detecting various sensitive data
10static API_KEY_PATTERNS: LazyLock<Vec<(Regex, &str, &str)>> = LazyLock::new(|| {
11    vec![
12        // Generic API keys
13        (
14            Regex::new(r"(?i)(api[_-]?key|apikey|api[_-]?secret)\s*[:=]\s*\S{8,}").unwrap(),
15            "api_key",
16            "Potential API key detected",
17        ),
18        // AWS keys
19        (
20            Regex::new(r"(?i)AKIA[0-9A-Z]{16}").unwrap(),
21            "aws_access_key",
22            "AWS access key ID detected",
23        ),
24        // AWS secret
25        (
26            Regex::new(r"(?i)aws[_-]?secret[_-]?access[_-]?key\s*[:=]\s*[a-zA-Z0-9/+=]{40}").unwrap(),
27            "aws_secret",
28            "AWS secret access key detected",
29        ),
30        // GitHub tokens
31        (
32            Regex::new(r"ghp_[a-zA-Z0-9]{36}").unwrap(),
33            "github_token",
34            "GitHub personal access token detected",
35        ),
36        (
37            Regex::new(r"gho_[a-zA-Z0-9]{36}").unwrap(),
38            "github_token",
39            "GitHub OAuth token detected",
40        ),
41        // Private keys
42        (
43            Regex::new(r"-----BEGIN (?:RSA |EC |DSA |OPENSSH )?PRIVATE KEY-----").unwrap(),
44            "private_key",
45            "Private key detected",
46        ),
47        // Bearer tokens
48        (
49            Regex::new(r"(?i)bearer\s+[a-zA-Z0-9_\-\.]{20,}").unwrap(),
50            "bearer_token",
51            "Bearer token detected",
52        ),
53        // Basic auth
54        (
55            Regex::new(r"(?i)basic\s+[a-zA-Z0-9+/=]{20,}").unwrap(),
56            "basic_auth",
57            "Basic auth credentials detected",
58        ),
59        // Database URLs with passwords
60        (
61            Regex::new(r"(?i)(postgres|mysql|mongodb|redis)://[^:]+:[^@]+@").unwrap(),
62            "db_url",
63            "Database URL with credentials detected",
64        ),
65        // Slack tokens
66        (
67            Regex::new(r"xox[baprs]-[0-9]{10,13}-[0-9]{10,13}-[a-zA-Z0-9]{24,}").unwrap(),
68            "slack_token",
69            "Slack token detected",
70        ),
71        // Discord tokens
72        (
73            Regex::new(r"[MN][A-Za-z\d]{23,}\.[\w-]{6}\.[\w-]{27}").unwrap(),
74            "discord_token",
75            "Discord token detected",
76        ),
77        // JWT tokens
78        (
79            Regex::new(r"eyJ[a-zA-Z0-9_-]*\.eyJ[a-zA-Z0-9_-]*\.[a-zA-Z0-9_-]*").unwrap(),
80            "jwt",
81            "JWT token detected",
82        ),
83        // Generic secrets
84        (
85            Regex::new(r"(?i)(secret|password|passwd|pwd|token|auth)\s*[:=]\s*\S{8,}").unwrap(),
86            "generic_secret",
87            "Potential secret detected",
88        ),
89        // SSH keys
90        (
91            Regex::new(r"ssh-rsa\s+[A-Za-z0-9+/=]{30,}").unwrap(),
92            "ssh_key",
93            "SSH key detected",
94        ),
95    ]
96});
97
98/// Result of an output scan
99#[derive(Debug, Clone)]
100pub struct ScanResult {
101    /// Whether any sensitive data was found
102    pub has_sensitive_data: bool,
103    /// List of detected items
104    pub findings: Vec<Finding>,
105}
106
/// A single occurrence of sensitive data found in scanned text.
#[derive(Debug, Clone)]
pub struct Finding {
    /// Short machine-readable tag for the kind of sensitive data.
    pub category: String,
    /// Human-readable description of the finding.
    pub description: String,
    /// The raw matched text. Prefer [`Finding::redacted`] when displaying it.
    pub matched_text: String,
    /// Byte offset of the start of the match in the scanned text.
    pub start: usize,
    /// Byte offset of the end of the match in the scanned text.
    pub end: usize,
}

impl Finding {
    /// Returns a redacted form of the matched text that is safe to display.
    ///
    /// Texts of eight characters or fewer are fully masked with one `*` per
    /// character. Longer texts keep their first and last four characters,
    /// joined by `...`.
    ///
    /// Lengths are measured in characters, not bytes, so a match containing
    /// multi-byte UTF-8 is never sliced in the middle of a character.
    pub fn redacted(&self) -> String {
        let text = self.matched_text.as_str();
        let char_count = text.chars().count();

        if char_count <= 8 {
            return "*".repeat(char_count);
        }

        // Byte offset just past the fourth character.
        let head_end = text
            .char_indices()
            .nth(4)
            .map_or(text.len(), |(offset, _)| offset);
        // Byte offset where the fourth-from-last character begins.
        let tail_start = text
            .char_indices()
            .rev()
            .nth(3)
            .map_or(0, |(offset, _)| offset);

        format!("{}...{}", &text[..head_end], &text[tail_start..])
    }
}
136
137/// Scan output for sensitive data
138///
139/// # Arguments
140/// * `output` - The text to scan
141/// * `strict` - If true, warn on more patterns (may have false positives)
142///
143/// # Returns
144/// A scan result with findings
145pub fn scan_output(output: &str, strict: bool) -> ScanResult {
146    let mut findings = Vec::new();
147
148    for (pattern, category, description) in API_KEY_PATTERNS.iter() {
149        // In non-strict mode, skip generic patterns that may have false positives
150        if !strict {
151            if *category == "generic_secret" || *category == "api_key" {
152                continue;
153            }
154        }
155
156        for mat in pattern.find_iter(output) {
157            findings.push(Finding {
158                category: category.to_string(),
159                description: description.to_string(),
160                matched_text: mat.as_str().to_string(),
161                start: mat.start(),
162                end: mat.end(),
163            });
164        }
165    }
166
167    ScanResult {
168        has_sensitive_data: !findings.is_empty(),
169        findings,
170    }
171}
172
173/// Scan and warn about sensitive data
174///
175/// Prints warnings to stderr but does not modify the output.
176pub fn warn_about_sensitive_data(output: &str) -> ScanResult {
177    let result = scan_output(output, false);
178
179    if result.has_sensitive_data {
180        for finding in &result.findings {
181            eprintln!(
182                "Warning: {} at position {}: {}",
183                finding.description,
184                finding.start,
185                finding.redacted()
186            );
187        }
188    }
189
190    result
191}
192
193/// Redact sensitive data from output
194///
195/// Returns the output with sensitive data replaced by [REDACTED].
196pub fn redact_sensitive_data(output: &str) -> String {
197    let mut result = output.to_string();
198
199    for (pattern, _, _) in API_KEY_PATTERNS.iter() {
200        result = pattern.replace_all(&result, "[REDACTED]").to_string();
201    }
202
203    result
204}
205
206/// Check if a specific string looks like a sensitive value
207pub fn is_sensitive_pattern(s: &str) -> bool {
208    if s.len() < 8 {
209        return false;
210    }
211
212    let patterns = [
213        r"^[a-zA-Z0-9_\-]{20,}$",
214        r"^xox[baprs]-",
215        r"^gh[pso]_[a-zA-Z0-9]{36}",
216        r"^AKIA[0-9A-Z]{16}$",
217        r"^Bearer\s+",
218    ];
219
220    patterns.iter().any(|p| {
221        Regex::new(p)
222            .map(|re| re.is_match(s))
223            .unwrap_or(false)
224    })
225}
226
227/// Get a list of all supported categories
228pub fn supported_categories() -> Vec<&'static str> {
229    API_KEY_PATTERNS
230        .iter()
231        .map(|(_, category, _)| *category)
232        .collect::<std::collections::HashSet<_>>()
233        .into_iter()
234        .collect()
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain text without secrets yields no findings.
    #[test]
    fn test_scan_no_sensitive_data() {
        let output = "Hello, this is a normal response without any secrets.";
        let result = scan_output(output, false);
        assert!(!result.has_sensitive_data);
        assert!(result.findings.is_empty());
    }

    #[test]
    fn test_scan_aws_key() {
        let output = "AWS Key: AKIAIOSFODNN7EXAMPLE";
        let result = scan_output(output, false);
        assert!(result.has_sensitive_data);
        assert_eq!(result.findings[0].category, "aws_access_key");
    }

    #[test]
    fn test_scan_github_token() {
        let output = "ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
        let result = scan_output(output, false);
        assert!(result.has_sensitive_data);
        assert_eq!(result.findings[0].category, "github_token");
    }

    #[test]
    fn test_scan_private_key() {
        let output = "-----BEGIN RSA PRIVATE KEY-----\nMIIEowIBAAKCAQ...\n-----END RSA PRIVATE KEY-----";
        let result = scan_output(output, false);
        assert!(result.has_sensitive_data);
        assert_eq!(result.findings[0].category, "private_key");
    }

    #[test]
    fn test_scan_jwt() {
        let output = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U";
        let result = scan_output(output, false);
        assert!(result.has_sensitive_data);
        assert_eq!(result.findings[0].category, "jwt");
    }

    #[test]
    fn test_scan_db_url() {
        let output = "postgres://user:password@localhost:5432/mydb";
        let result = scan_output(output, false);
        assert!(result.has_sensitive_data);
        assert_eq!(result.findings[0].category, "db_url");
    }

    #[test]
    fn test_redact() {
        let output = "My GitHub token is ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";
        let redacted = redact_sensitive_data(output);
        assert!(redacted.contains("[REDACTED]"));
        assert!(!redacted.contains("ghp_"));
    }

    #[test]
    fn test_finding_redacted() {
        let finding = Finding {
            category: "github_token".to_string(),
            description: "GitHub token".to_string(),
            matched_text: "ghp_abcdefghij1234567890abcdefghij12".to_string(),
            start: 0,
            // Matches the 36-byte length of `matched_text`.
            end: 36,
        };
        let redacted = finding.redacted();
        assert!(redacted.starts_with("ghp_"));
        assert!(redacted.ends_with("ij12"));
        assert!(redacted.contains("..."));
    }

    #[test]
    fn test_is_sensitive_pattern() {
        assert!(is_sensitive_pattern("ghp_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"));
        assert!(is_sensitive_pattern("AKIAIOSFODNN7EXAMPLE"));
        assert!(!is_sensitive_pattern("hello"));
        assert!(!is_sensitive_pattern("short"));
    }

    #[test]
    fn test_supported_categories() {
        let categories = supported_categories();
        assert!(categories.contains(&"aws_access_key"));
        assert!(categories.contains(&"github_token"));
        assert!(categories.contains(&"private_key"));
        // Two patterns share the `github_token` category; it is listed once.
        let github_entries = categories.iter().filter(|c| **c == "github_token").count();
        assert_eq!(github_entries, 1);
    }

    #[test]
    fn test_strict_vs_non_strict() {
        // Specific patterns such as AWS keys are reported in both modes.
        let aws = "AWS Key: AKIAIOSFODNN7EXAMPLE";
        assert!(
            scan_output(aws, false).has_sensitive_data,
            "non_strict should detect AWS keys"
        );
        assert!(
            scan_output(aws, true).has_sensitive_data,
            "strict should detect AWS keys"
        );

        // Generic patterns are only reported in strict mode.
        let generic = "password = hunter2hunter2";
        let non_strict = scan_output(generic, false);
        let strict = scan_output(generic, true);
        assert!(
            !non_strict.has_sensitive_data,
            "non_strict should skip generic secrets"
        );
        assert!(
            strict.has_sensitive_data,
            "strict should detect generic secrets"
        );
        assert_eq!(strict.findings[0].category, "generic_secret");
        assert!(strict.findings.len() > non_strict.findings.len());
    }
}