Skip to main content

agentshield/parser/
python.rs

1use std::path::{Path, PathBuf};
2
3use once_cell::sync::Lazy;
4use regex::Regex;
5
6use super::{LanguageParser, ParsedFile};
7use crate::error::Result;
8use crate::ir::execution_surface::*;
9use crate::ir::{ArgumentSource, Language, SourceLocation};
10
/// Python implementation of `LanguageParser`.
///
/// A stateless unit struct: all detection state lives in the module-level
/// pattern tables and regexes, and `parse_file` scans the source text line by
/// line rather than building a syntax tree.
pub struct PythonParser;
12
13// Dangerous subprocess/exec functions
14static SUBPROCESS_PATTERNS: Lazy<Vec<&str>> = Lazy::new(|| {
15    vec![
16        "subprocess.run",
17        "subprocess.call",
18        "subprocess.check_call",
19        "subprocess.check_output",
20        "subprocess.Popen",
21        "os.system",
22        "os.popen",
23        "os.exec",
24        "os.execv",
25        "os.execve",
26        "os.execvp",
27    ]
28});
29
30static NETWORK_PATTERNS: Lazy<Vec<&str>> = Lazy::new(|| {
31    vec![
32        "requests.get",
33        "requests.post",
34        "requests.put",
35        "requests.patch",
36        "requests.delete",
37        "requests.head",
38        "requests.request",
39        "urllib.request.urlopen",
40        "httpx.get",
41        "httpx.post",
42        "httpx.put",
43        "httpx.AsyncClient",
44        "aiohttp.ClientSession",
45    ]
46});
47
48static DYNAMIC_EXEC_PATTERNS: Lazy<Vec<&str>> =
49    Lazy::new(|| vec!["eval", "exec", "compile", "__import__"]);
50
51static SENSITIVE_ENV_VARS: Lazy<Regex> = Lazy::new(|| {
52    Regex::new(r"(?i)(AWS_|SECRET|TOKEN|PASSWORD|API_KEY|PRIVATE_KEY|CREDENTIALS|AUTH)").unwrap()
53});
54
55static FILE_READ_PATTERNS: Lazy<Vec<&str>> = Lazy::new(|| vec!["open", "pathlib.Path"]);
56
57// Regex to find function calls with arguments: func_name(args)
58static CALL_RE: Lazy<Regex> =
59    Lazy::new(|| Regex::new(r"(?m)(\w+(?:\.\w+)*)\s*\(([^)]*)\)").unwrap());
60
61// Regex to find os.environ / os.getenv patterns
62static ENV_ACCESS_RE: Lazy<Regex> = Lazy::new(|| {
63    Regex::new(
64        r#"(?m)os\.(?:environ\s*(?:\[\s*["']([^"']+)["']\s*\]|\.get\s*\(\s*["']([^"']+)["'])|getenv\s*\(\s*["']([^"']+)["']\s*\))"#,
65    )
66    .unwrap()
67});
68
69// Regex to find function definitions and their parameters
70static FUNC_DEF_RE: Lazy<Regex> =
71    Lazy::new(|| Regex::new(r"(?m)^\s*(?:async\s+)?def\s+(\w+)\s*\(([^)]*)\)").unwrap());
72
73impl LanguageParser for PythonParser {
74    fn language(&self) -> Language {
75        Language::Python
76    }
77
78    fn parse_file(&self, path: &Path, content: &str) -> Result<ParsedFile> {
79        let mut parsed = ParsedFile::default();
80        let file_path = PathBuf::from(path);
81
82        // Collect function parameter names for taint tracking
83        let mut param_names = std::collections::HashSet::new();
84        for cap in FUNC_DEF_RE.captures_iter(content) {
85            let params = &cap[2];
86            for param in params.split(',') {
87                let param = param.trim().split(':').next().unwrap_or("").trim();
88                let param = param.split('=').next().unwrap_or("").trim();
89                if !param.is_empty() && param != "self" && param != "cls" {
90                    param_names.insert(param.to_string());
91                }
92            }
93        }
94
95        // Scan line by line for patterns
96        for (line_idx, line) in content.lines().enumerate() {
97            let line_num = line_idx + 1;
98            let trimmed = line.trim();
99
100            // Skip comments
101            if trimmed.starts_with('#') {
102                continue;
103            }
104
105            // Check env var access
106            for cap in ENV_ACCESS_RE.captures_iter(line) {
107                let var_name = cap
108                    .get(1)
109                    .or_else(|| cap.get(2))
110                    .or_else(|| cap.get(3))
111                    .map(|m| m.as_str().to_string())
112                    .unwrap_or_default();
113                let is_sensitive = SENSITIVE_ENV_VARS.is_match(&var_name);
114                parsed.env_accesses.push(EnvAccess {
115                    var_name: ArgumentSource::Literal(var_name),
116                    is_sensitive,
117                    location: loc(&file_path, line_num),
118                });
119            }
120
121            // Check function calls
122            for cap in CALL_RE.captures_iter(line) {
123                let func_name = &cap[1];
124                let args_str = &cap[2];
125
126                let arg_source = classify_argument(args_str, &param_names);
127
128                // Subprocess/command execution
129                if SUBPROCESS_PATTERNS
130                    .iter()
131                    .any(|p| func_name.ends_with(p) || func_name == *p)
132                {
133                    parsed.commands.push(CommandInvocation {
134                        function: func_name.to_string(),
135                        command_arg: arg_source.clone(),
136                        location: loc(&file_path, line_num),
137                    });
138                }
139
140                // Network operations
141                if NETWORK_PATTERNS
142                    .iter()
143                    .any(|p| func_name.ends_with(p) || func_name == *p)
144                {
145                    let sends_data = func_name.contains("post")
146                        || func_name.contains("put")
147                        || func_name.contains("patch")
148                        || args_str.contains("data=")
149                        || args_str.contains("json=");
150                    let method = if func_name.contains("get") {
151                        Some("GET".into())
152                    } else if func_name.contains("post") {
153                        Some("POST".into())
154                    } else if func_name.contains("put") {
155                        Some("PUT".into())
156                    } else {
157                        None
158                    };
159                    parsed.network_operations.push(NetworkOperation {
160                        function: func_name.to_string(),
161                        url_arg: arg_source.clone(),
162                        method,
163                        sends_data,
164                        location: loc(&file_path, line_num),
165                    });
166                }
167
168                // Dynamic exec
169                if DYNAMIC_EXEC_PATTERNS.contains(&func_name) {
170                    parsed.dynamic_exec.push(DynamicExec {
171                        function: func_name.to_string(),
172                        code_arg: arg_source.clone(),
173                        location: loc(&file_path, line_num),
174                    });
175                }
176
177                // File operations (open with write mode)
178                if FILE_READ_PATTERNS
179                    .iter()
180                    .any(|p| func_name.ends_with(p) || func_name == *p)
181                {
182                    let op_type = if args_str.contains("'w")
183                        || args_str.contains("\"w")
184                        || args_str.contains("'a")
185                        || args_str.contains("\"a")
186                    {
187                        FileOpType::Write
188                    } else {
189                        FileOpType::Read
190                    };
191                    parsed.file_operations.push(FileOperation {
192                        operation: op_type,
193                        path_arg: arg_source.clone(),
194                        location: loc(&file_path, line_num),
195                    });
196                }
197            }
198        }
199
200        Ok(parsed)
201    }
202}
203
204/// Classify a call argument string to determine its source.
205fn classify_argument(
206    args_str: &str,
207    param_names: &std::collections::HashSet<String>,
208) -> ArgumentSource {
209    let first_arg = args_str.split(',').next().unwrap_or("").trim();
210
211    if first_arg.is_empty() {
212        return ArgumentSource::Unknown;
213    }
214
215    // String literal
216    if (first_arg.starts_with('"') && first_arg.ends_with('"'))
217        || (first_arg.starts_with('\'') && first_arg.ends_with('\''))
218    {
219        let val = &first_arg[1..first_arg.len() - 1];
220        return ArgumentSource::Literal(val.to_string());
221    }
222
223    // f-string or format
224    if first_arg.starts_with("f\"") || first_arg.starts_with("f'") || first_arg.contains(".format(")
225    {
226        return ArgumentSource::Interpolated;
227    }
228
229    // os.environ / env var
230    if first_arg.contains("os.environ") || first_arg.contains("os.getenv") {
231        return ArgumentSource::EnvVar {
232            name: first_arg.to_string(),
233        };
234    }
235
236    // Known function parameter
237    let ident = first_arg.split('.').next().unwrap_or(first_arg);
238    if param_names.contains(ident) {
239        return ArgumentSource::Parameter {
240            name: ident.to_string(),
241        };
242    }
243
244    ArgumentSource::Unknown
245}
246
247fn loc(file: &Path, line: usize) -> SourceLocation {
248    SourceLocation {
249        file: file.to_path_buf(),
250        line,
251        column: 0,
252        end_line: None,
253        end_column: None,
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `code` as if it were the contents of `test.py`.
    fn parse(code: &str) -> ParsedFile {
        PythonParser
            .parse_file(Path::new("test.py"), code)
            .expect("parsing Python test input should not fail")
    }

    #[test]
    fn detects_subprocess_with_param() {
        let code = r#"
def handle(cmd: str):
    subprocess.run(cmd, shell=True)
"#;
        let parsed = parse(code);
        assert_eq!(parsed.commands.len(), 1);
        assert!(matches!(
            parsed.commands[0].command_arg,
            ArgumentSource::Parameter { .. }
        ));
    }

    #[test]
    fn literal_command_is_captured_verbatim() {
        let parsed = parse("subprocess.run(\"ls -la\", shell=True)\n");
        assert_eq!(parsed.commands.len(), 1);
        assert!(matches!(
            &parsed.commands[0].command_arg,
            ArgumentSource::Literal(s) if s == "ls -la"
        ));
    }

    #[test]
    fn detects_requests_get_with_param() {
        let code = r#"
def fetch(url: str):
    requests.get(url)
"#;
        let parsed = parse(code);
        assert_eq!(parsed.network_operations.len(), 1);
        assert!(matches!(
            parsed.network_operations[0].url_arg,
            ArgumentSource::Parameter { .. }
        ));
    }

    #[test]
    fn detects_post_with_json_body() {
        let code = r#"
def send(url, payload):
    requests.post(url, json=payload)
"#;
        let parsed = parse(code);
        assert_eq!(parsed.network_operations.len(), 1);
        let op = &parsed.network_operations[0];
        assert_eq!(op.method.as_deref(), Some("POST"));
        assert!(op.sends_data);
        assert!(matches!(op.url_arg, ArgumentSource::Parameter { .. }));
    }

    #[test]
    fn safe_literal_not_flagged_as_param() {
        let code = r#"
def fetch():
    requests.get("https://api.example.com")
"#;
        let parsed = parse(code);
        assert_eq!(parsed.network_operations.len(), 1);
        assert!(matches!(
            parsed.network_operations[0].url_arg,
            ArgumentSource::Literal(_)
        ));
    }

    #[test]
    fn detects_env_var_access() {
        let code = r#"
key = os.environ["AWS_SECRET_ACCESS_KEY"]
"#;
        let parsed = parse(code);
        assert_eq!(parsed.env_accesses.len(), 1);
        assert!(parsed.env_accesses[0].is_sensitive);
    }

    #[test]
    fn getenv_of_non_secret_is_not_sensitive() {
        let parsed = parse("home = os.getenv(\"HOME\")\n");
        assert_eq!(parsed.env_accesses.len(), 1);
        assert!(!parsed.env_accesses[0].is_sensitive);
    }

    #[test]
    fn detects_eval() {
        let code = r#"
def run(code):
    eval(code)
"#;
        let parsed = parse(code);
        assert_eq!(parsed.dynamic_exec.len(), 1);
        assert!(matches!(
            parsed.dynamic_exec[0].code_arg,
            ArgumentSource::Parameter { .. }
        ));
    }

    #[test]
    fn ignores_full_line_comments() {
        let parsed = parse("def run(code):\n    # eval(code)\n");
        assert!(parsed.dynamic_exec.is_empty());
    }
}