use super::predict::{
extract_ssrf_safe_reason, extract_ssrf_vulnerable_source, matches_allowlist_call,
matches_private_ip_guard, matches_scheme_hostname_allowlist, matches_user_input, Evidence,
HttpApi,
};
use crate::detectors::security::ast_helpers::{enclosing_python_function, node_text};
use std::collections::{HashMap, HashSet};
use tree_sitter::Node;
/// One HTTP call site found in a Python module by `collect_python_http_sites`.
pub(super) struct PythonHttpSite<'a> {
/// The tree-sitter `call` node of the call expression.
pub call_node: Node<'a>,
/// API family that `classify_http_api` assigned to the callee.
pub api: HttpApi,
}
/// Collects every call expression in a Python module that looks like an HTTP request.
///
/// Returns an empty list immediately when `collect_http_imports` finds no recognised
/// top-level import. Otherwise every `call` node in the tree is visited, and a call is
/// kept when its callee text passes `is_http_callee` and the API returned by
/// `classify_http_api` reports `is_python()`.
///
/// Sites are returned in the order the explicit-stack walk visits them (the most
/// recently pushed child is popped first), which is not guaranteed to be source order.
pub(super) fn collect_python_http_sites<'a>(
    module_root: Node<'a>,
    source: &'a [u8],
) -> Vec<PythonHttpSite<'a>> {
    let imports = collect_http_imports(module_root, source);
    if imports.is_empty() {
        return Vec::new();
    }
    let aliases = collect_http_aliases(module_root, source);
    let mut sites = Vec::new();

    // A single cursor is reused for the whole walk: `Node::children` re-targets the
    // cursor at the node it is called on, so a fresh clone per node is unnecessary.
    let mut cursor = module_root.walk();
    let mut stack: Vec<Node<'a>> = vec![module_root];
    while let Some(node) = stack.pop() {
        // Queue the children first so nested calls (e.g. `Session().get(...)`) are
        // still reached even when this node itself is skipped below.
        stack.extend(node.children(&mut cursor));

        if node.kind() != "call" {
            continue;
        }
        let Some(func) = node.child_by_field_name("function") else {
            continue;
        };
        let func_text = node_text(func, source).unwrap_or("");
        if !is_http_callee(func_text) {
            continue;
        }
        let api = classify_http_api(func_text, &imports, &aliases);
        if !api.is_python() {
            continue;
        }
        sites.push(PythonHttpSite {
            call_node: node,
            api,
        });
    }
    sites
}
/// Returns `true` when the last dot-separated segment of `func_text` is one of the
/// HTTP verb, opener or client-constructor names this detector recognises.
fn is_http_callee(func_text: &str) -> bool {
    const HTTP_CALLEE_NAMES: &[&str] = &[
        "get",
        "post",
        "put",
        "delete",
        "patch",
        "head",
        "options",
        "request",
        "urlopen",
        "Request",
        "Session",
        "AsyncClient",
        "ClientSession",
        "fetch",
    ];

    // Everything after the final '.', or the whole text when there is no dot.
    let callee_name = match func_text.rfind('.') {
        Some(dot) => &func_text[dot + 1..],
        None => func_text,
    };
    HTTP_CALLEE_NAMES.contains(&callee_name)
}
pub(super) fn extract_python_evidence<'a>(
call_node: Node<'a>,
module_root: Node<'a>,
source: &'a [u8],
lines: &[&str],
) -> Evidence {
let mut ev = Evidence::default();
let imports = collect_http_imports(module_root, source);
ev.import_advocate = imports
.iter()
.any(|m| m == "advocate" || m.starts_with("advocate."));
ev.import_defusedurl = imports.iter().any(|m| {
m == "defusedurl"
|| m.starts_with("defusedurl.")
|| m == "safe_url_check"
|| m.starts_with("safe_url_check.")
});
ev.import_validators = imports
.iter()
.any(|m| m == "validators" || m.starts_with("validators."));
if let Some(fn_node) = enclosing_python_function(call_node) {
if let Some(name_node) = fn_node.child_by_field_name("name") {
if let Some(name) = node_text(name_node, source) {
ev.enclosing_function = Some(name.to_string());
}
}
}
ev.enclosing_class = enclosing_python_class_name(call_node, source);
let aliases = collect_http_aliases(module_root, source);
let func_text = call_node
.child_by_field_name("function")
.and_then(|f| node_text(f, source))
.unwrap_or("");
ev.api = Some(classify_http_api(func_text, &imports, &aliases));
if let Some(args) = call_node.child_by_field_name("arguments") {
if let Some(first_arg) = first_positional_arg(args) {
ev.url_fstring_or_concat = is_fstring_or_concat(first_arg, source);
}
}
let line_idx = call_node.start_position().row;
let start = line_idx.saturating_sub(10);
let window_end = line_idx + 1;
let lookback = if window_end <= lines.len() {
&lines[start..window_end]
} else {
&lines[start..lines.len()]
};
let window_str: String = lookback.join("\n");
ev.has_user_input_flow = matches_user_input(&window_str);
ev.has_allowlist_call = matches_allowlist_call(&window_str);
ev.has_scheme_hostname_allowlist = matches_scheme_hostname_allowlist(&window_str);
ev.has_private_ip_guard = matches_private_ip_guard(&window_str);
if let Some(line) = lines.get(line_idx) {
ev.ssrf_safe_annotation = extract_ssrf_safe_reason(line);
ev.ssrf_vulnerable_annotation = extract_ssrf_vulnerable_source(line);
}
ev
}
/// Returns the first named child of an argument list that is not a
/// `keyword_argument`, or `None` when there is no such child.
fn first_positional_arg<'a>(args_node: Node<'a>) -> Option<Node<'a>> {
    let mut cursor = args_node.walk();
    // Unnamed children are punctuation tokens; keyword arguments are not positional.
    let mut positional = args_node
        .children(&mut cursor)
        .filter(|arg| arg.is_named() && arg.kind() != "keyword_argument");
    positional.next()
}
/// Reports whether an expression node is a string built at run time: either a
/// `string` node containing an `interpolation` child, or a `+` concatenation that
/// involves at least one `string` operand.
///
/// `+` chains nest on the left (`a + b + c` is `(a + b) + c`), so the check walks down
/// the left operand for as long as it is itself a `+` expression. This lets
/// `'http://' + host + path` count as a concatenation even though neither operand of
/// the outermost `+` is a string literal. Any other node kind, and any non-`+`
/// operator met on the way, yields `false`.
fn is_fstring_or_concat(node: Node<'_>, source: &[u8]) -> bool {
    match node.kind() {
        "string" => {
            let mut cursor = node.walk();
            let mut parts = node.children(&mut cursor);
            parts.any(|part| part.kind() == "interpolation")
        }
        "binary_operator" => {
            let is_string =
                |operand: Option<Node<'_>>| operand.map_or(false, |n| n.kind() == "string");

            // Iterative descent (instead of recursion) keeps stack use flat for very
            // long concatenation chains.
            let mut current = node;
            loop {
                let operator = current
                    .child_by_field_name("operator")
                    .and_then(|op| node_text(op, source))
                    .unwrap_or("");
                if operator != "+" {
                    return false;
                }
                let left = current.child_by_field_name("left");
                let right = current.child_by_field_name("right");
                if is_string(left) || is_string(right) {
                    return true;
                }
                match left {
                    Some(inner) if inner.kind() == "binary_operator" => current = inner,
                    _ => return false,
                }
            }
        }
        _ => false,
    }
}
/// Collects the dotted names of recognised modules (see `is_http_module`) imported by
/// statements that are direct children of `root`.
///
/// For `import x` / `import x as y` the module name `x` is recorded; for
/// `from x import ...` the `module_name` field `x` is recorded. Names that do not pass
/// `is_http_module` are ignored.
fn collect_http_imports<'a>(root: Node<'a>, source: &'a [u8]) -> HashSet<String> {
    let mut modules = HashSet::new();
    let mut statement_cursor = root.walk();
    for statement in root.children(&mut statement_cursor) {
        let kind = statement.kind();
        if kind == "import_from_statement" {
            let module = statement
                .child_by_field_name("module_name")
                .and_then(|name_node| node_text(name_node, source))
                .filter(|name| is_http_module(name));
            if let Some(name) = module {
                modules.insert(name.to_string());
            }
        } else if kind == "import_statement" {
            let mut item_cursor = statement.walk();
            let imported = statement
                .children(&mut item_cursor)
                .filter(|item| item.is_named())
                .filter_map(|item| match item.kind() {
                    "dotted_name" => node_text(item, source),
                    // `import x as y`: the module is in the `name` field.
                    "aliased_import" => item
                        .child_by_field_name("name")
                        .and_then(|name_node| node_text(name_node, source)),
                    _ => None,
                })
                .filter(|name| is_http_module(name))
                .map(str::to_string);
            modules.extend(imported);
        }
    }
    modules
}
/// Returns `true` when `name` is one of the tracked modules or a dotted submodule of
/// one (`requests`, `requests.sessions`, ...). A shared textual prefix alone is not
/// enough: `requestor` does not match `requests`.
fn is_http_module(name: &str) -> bool {
    const TRACKED_MODULES: &[&str] = &[
        "advocate",
        "defusedurl",
        "safe_url_check",
        "validators",
        "requests",
        "urllib",
        "urllib.request",
        "urllib2",
        "urllib3",
        "httpx",
        "aiohttp",
    ];

    TRACKED_MODULES.iter().copied().any(|module| {
        // After removing the module prefix, either nothing is left (exact match) or
        // the remainder continues with a '.' (submodule).
        name.strip_prefix(module)
            .map_or(false, |rest| rest.is_empty() || rest.starts_with('.'))
    })
}
/// Maps locally bound names to the recognised module they come from, looking only at
/// import statements that are direct children of `root`.
///
/// * `import requests as r` records `r -> requests`.
/// * `from requests import get` records `get -> requests`.
/// * `from requests import get as g` records `g -> requests`.
///
/// Modules that do not pass `is_http_module` contribute nothing.
fn collect_http_aliases<'a>(root: Node<'a>, source: &'a [u8]) -> HashMap<String, String> {
    let mut bindings = HashMap::new();
    let mut statement_cursor = root.walk();
    for statement in root.children(&mut statement_cursor) {
        match statement.kind() {
            "import_statement" => {
                let mut item_cursor = statement.walk();
                for item in statement.children(&mut item_cursor) {
                    if !item.is_named() || item.kind() != "aliased_import" {
                        continue;
                    }
                    let module_text = item
                        .child_by_field_name("name")
                        .and_then(|n| node_text(n, source));
                    let alias_text = item
                        .child_by_field_name("alias")
                        .and_then(|n| node_text(n, source));
                    let (Some(module), Some(alias)) = (module_text, alias_text) else {
                        continue;
                    };
                    if is_http_module(module) {
                        bindings.insert(alias.to_string(), module.to_string());
                    }
                }
            }
            "import_from_statement" => {
                let Some(module_node) = statement.child_by_field_name("module_name") else {
                    continue;
                };
                let Some(module) = node_text(module_node, source) else {
                    continue;
                };
                if !is_http_module(module) {
                    continue;
                }
                let mut item_cursor = statement.walk();
                for item in statement.children(&mut item_cursor) {
                    // The module name is itself a named child; it is not an imported name.
                    if !item.is_named() || item.id() == module_node.id() {
                        continue;
                    }
                    let bound_name = match item.kind() {
                        "dotted_name" => node_text(item, source),
                        "aliased_import" => item
                            .child_by_field_name("alias")
                            .and_then(|n| node_text(n, source)),
                        _ => None,
                    };
                    if let Some(name) = bound_name {
                        bindings.insert(name.to_string(), module.to_string());
                    }
                }
            }
            _ => {}
        }
    }
    bindings
}
/// Decides which HTTP library a callee expression belongs to.
///
/// Resolution order:
/// 1. The first segment of the dotted callee that is either an alias of a classifiable
///    module or itself a classifiable module name.
/// 2. A callee whose leftmost identifier is `urlopen` is urllib.
/// 3. A callee whose leftmost identifier is `Session`, `AsyncClient` or `ClientSession`
///    is attributed to the first imported library in the order advocate, httpx,
///    aiohttp, requests.
/// 4. If exactly one of advocate / requests / urllib / httpx / aiohttp is imported,
///    that library.
///
/// Anything else is `HttpApi::Unknown`.
fn classify_http_api(
    func_text: &str,
    imports: &HashSet<String>,
    aliases: &HashMap<String, String>,
) -> HttpApi {
    let from_chain = chain_identifiers(func_text).into_iter().find_map(|segment| {
        aliases
            .get(segment)
            .and_then(|module| http_api_from_module(module))
            .or_else(|| http_api_from_module(segment))
    });
    if let Some(api) = from_chain {
        return api;
    }

    let head = leftmost_identifier(func_text);
    if head == "urlopen" {
        return HttpApi::Urllib;
    }

    if matches!(head, "Session" | "AsyncClient" | "ClientSession") {
        let by_priority = [
            ("advocate", HttpApi::Advocate),
            ("httpx", HttpApi::Httpx),
            ("aiohttp", HttpApi::Aiohttp),
            ("requests", HttpApi::Requests),
        ];
        for (prefix, api) in by_priority {
            if imports.iter().any(|m| m.starts_with(prefix)) {
                return api;
            }
        }
    }

    // Last resort: attribute the call to the only HTTP library that is imported.
    let mut imported_libraries = [
        ("advocate", HttpApi::Advocate),
        ("requests", HttpApi::Requests),
        ("urllib", HttpApi::Urllib),
        ("httpx", HttpApi::Httpx),
        ("aiohttp", HttpApi::Aiohttp),
    ]
    .into_iter()
    .filter(|(lib, _)| {
        imports.iter().any(|m| {
            m.strip_prefix(*lib)
                .map_or(false, |rest| rest.is_empty() || rest.starts_with('.'))
        })
    });
    match (imported_libraries.next(), imported_libraries.next()) {
        (Some((_, api)), None) => api,
        _ => HttpApi::Unknown,
    }
}
/// Splits a dotted callee expression into its segments, cutting each segment at its
/// first `(` and dropping segments that end up empty.
///
/// `Session().get` becomes `["Session", "get"]`.
fn chain_identifiers(text: &str) -> Vec<&str> {
    let mut identifiers = Vec::new();
    for segment in text.split('.') {
        // Keep only the part before any call parentheses.
        let name = segment.split('(').next().unwrap_or(segment);
        if !name.is_empty() {
            identifiers.push(name);
        }
    }
    identifiers
}
/// Maps a module path to its HTTP API family, or `None` for modules that are not
/// classified. Only the top-level package (the text before the first `.`) matters, so
/// `requests.sessions` classifies the same way as `requests`.
fn http_api_from_module(module: &str) -> Option<HttpApi> {
    let top_level = module.split('.').next().unwrap_or(module);
    match top_level {
        "advocate" => Some(HttpApi::Advocate),
        "requests" => Some(HttpApi::Requests),
        "urllib" | "urllib2" => Some(HttpApi::Urllib),
        "httpx" => Some(HttpApi::Httpx),
        "aiohttp" => Some(HttpApi::Aiohttp),
        _ => None,
    }
}
/// Returns the text before the first `.` in `text`, or all of `text` when it contains
/// no dot.
fn leftmost_identifier(text: &str) -> &str {
    text.split_once('.').map_or(text, |(head, _rest)| head)
}
/// Walks up from `node` and returns the name of the nearest enclosing
/// `class_definition`. Returns `None` when the walk reaches the `module` node or runs
/// out of parents first, or when the class node has no readable `name` field.
fn enclosing_python_class_name<'a>(node: Node<'a>, source: &'a [u8]) -> Option<String> {
    let mut ancestor = node.parent();
    while let Some(current) = ancestor {
        match current.kind() {
            "class_definition" => {
                return current
                    .child_by_field_name("name")
                    .and_then(|name_node| node_text(name_node, source))
                    .map(str::to_string);
            }
            "module" => return None,
            _ => ancestor = current.parent(),
        }
    }
    None
}
// Unit tests for the Python evidence extractor. Most tests parse a small Python
// snippet, locate the first call whose callee ends in a given attribute name, run
// `extract_python_evidence` on it and check individual `Evidence` fields.
#[cfg(test)]
mod tests {
use super::*;
use crate::detectors::ast_fingerprint::parse_root_ext;
use crate::parsers::lightweight::Language;
// Returns the first `call` node (pre-order, depth-first) whose callee text ends in
// `.attr_name` or equals `attr_name`. Panics when the snippet has no such call.
fn first_call_with_attr<'tree>(
tree: &'tree tree_sitter::Tree,
source: &[u8],
attr_name: &str,
) -> tree_sitter::Node<'tree> {
// Recursive pre-order search; the node itself is tested before its children.
fn walk<'a>(
node: tree_sitter::Node<'a>,
source: &[u8],
attr_name: &str,
) -> Option<tree_sitter::Node<'a>> {
if node.kind() == "call" {
if let Some(func) = node.child_by_field_name("function") {
let text = node_text(func, source).unwrap_or("");
let last = text.rsplit('.').next().unwrap_or("");
if last == attr_name {
return Some(node);
}
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if let Some(found) = walk(child, source, attr_name) {
return Some(found);
}
}
None
}
walk(tree.root_node(), source, attr_name)
.unwrap_or_else(|| panic!("no call to {} found in source", attr_name))
}
// Parses `src` as Python and extracts evidence for the first call ending in `attr`.
fn extract(src: &str, attr: &str) -> Evidence {
let tree = parse_root_ext(src, Language::Python, "py").expect("parse python");
let root = tree.root_node();
let call = first_call_with_attr(&tree, src.as_bytes(), attr);
let lines: Vec<&str> = src.lines().collect();
extract_python_evidence(call, root, src.as_bytes(), &lines)
}
// --- Import detection and API classification ---
#[test]
fn detects_advocate_import() {
let src = "from advocate import Session\nSession().get('http://x')\n";
let ev = extract(src, "get");
assert!(ev.import_advocate);
assert_eq!(ev.api, Some(HttpApi::Advocate));
}
#[test]
fn detects_requests_import() {
let src = "import requests\nrequests.get('http://x')\n";
let ev = extract(src, "get");
assert!(!ev.import_advocate);
assert_eq!(ev.api, Some(HttpApi::Requests));
}
#[test]
fn detects_urllib_urlopen() {
let src = "from urllib.request import urlopen\nurlopen('http://x')\n";
let ev = extract(src, "urlopen");
assert_eq!(ev.api, Some(HttpApi::Urllib));
}
#[test]
fn detects_httpx_import() {
let src = "import httpx\nhttpx.get('http://x')\n";
let ev = extract(src, "get");
assert_eq!(ev.api, Some(HttpApi::Httpx));
}
// The callee here is `s.get`, which names no library; the expected result relies on
// aiohttp being the only HTTP library imported by the snippet.
#[test]
fn detects_aiohttp_import() {
let src = "import aiohttp\nasync def f():\n async with aiohttp.ClientSession() as s:\n await s.get('http://x')\n";
let ev = extract(src, "get");
assert_eq!(ev.api, Some(HttpApi::Aiohttp));
}
#[test]
fn detects_validators_import() {
let src = "\
import validators\n\
import requests\n\
def f(url):\n\
\x20 if validators.url(url):\n\
\x20 requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.import_validators);
assert!(ev.has_allowlist_call);
}
#[test]
fn detects_defusedurl_import() {
let src = "\
from defusedurl import is_safe_url\n\
import requests\n\
def f(url):\n\
\x20 if is_safe_url(url):\n\
\x20 requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.import_defusedurl);
assert!(ev.has_allowlist_call);
}
#[test]
fn aliased_advocate_import_classifies_correctly() {
let src = "\
import advocate as a\n\
a.get('http://x')\n";
let ev = extract(src, "get");
assert!(ev.import_advocate);
assert_eq!(ev.api, Some(HttpApi::Advocate));
}
// --- User-input flow heuristics (look-back window) ---
#[test]
fn detects_request_body_within_lookback_window() {
let src = "\
import requests\n\
def handle(request):\n\
\x20 url = request.body['url']\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.has_user_input_flow);
}
#[test]
fn detects_inline_user_input_on_call_line() {
let src = "\
import requests\n\
def handle(request):\n\
\x20 return requests.get(request.json['url'])\n";
let ev = extract(src, "get");
assert!(ev.has_user_input_flow);
}
#[test]
fn no_user_input_flow_for_hardcoded_url() {
let src = "\
import requests\n\
def f():\n\
\x20 return requests.get('https://example.com/data')\n";
let ev = extract(src, "get");
assert!(!ev.has_user_input_flow);
}
#[test]
fn detects_request_args_input() {
let src = "\
import requests\n\
def handle(request):\n\
\x20 target = request.args.get('u')\n\
\x20 return requests.get(target)\n";
let ev = extract(src, "get");
assert!(ev.has_user_input_flow);
}
// --- Allowlist / validation heuristics ---
#[test]
fn detects_is_safe_url_call() {
let src = "\
import requests\n\
def f(url):\n\
\x20 if is_safe_url(url):\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.has_allowlist_call);
}
#[test]
fn detects_validators_url_call() {
let src = "\
import validators\n\
import requests\n\
def f(url):\n\
\x20 if validators.url(url, public=False):\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.has_allowlist_call);
}
// NOTE(review): despite its name this test asserts nothing about
// `has_allowlist_call`; it only checks that extraction completes for a snippet whose
// sole mention of `is_safe_url()` is inside a comment. Confirm the intended
// expectation before adding an assertion.
#[test]
fn does_not_fire_allowlist_on_mere_comment() {
let src = "\
import requests\n\
def f(url):\n\
\x20 # remember to call is_safe_url() upstream\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
let _ = ev;
}
#[test]
fn detects_scheme_allowlist() {
let src = "\
import requests\n\
from urllib.parse import urlparse\n\
def f(url):\n\
\x20 parsed = urlparse(url)\n\
\x20 if parsed.scheme in {'http', 'https'}:\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.has_scheme_hostname_allowlist);
}
#[test]
fn detects_hostname_allowlist() {
let src = "\
import requests\n\
from urllib.parse import urlparse\n\
ALLOWED_HOSTS = {'x.com', 'y.com'}\n\
def f(url):\n\
\x20 parsed = urlparse(url)\n\
\x20 if parsed.hostname in ALLOWED_HOSTS:\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.has_scheme_hostname_allowlist);
}
// --- Private / loopback IP guards ---
#[test]
fn detects_is_private_guard() {
let src = "\
import ipaddress\n\
import requests\n\
def f(host, url):\n\
\x20 if ipaddress.ip_address(host).is_private:\n\
\x20 raise ValueError('blocked')\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.has_private_ip_guard);
}
#[test]
fn detects_is_loopback_guard() {
let src = "\
import ipaddress\n\
import requests\n\
def f(host, url):\n\
\x20 if ipaddress.ip_address(host).is_loopback:\n\
\x20 return None\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.has_private_ip_guard);
}
// --- Shape of the URL argument ---
#[test]
fn detects_fstring_url() {
let src = "\
import requests\n\
def f(host):\n\
\x20 return requests.get(f'http://{host}/api')\n";
let ev = extract(src, "get");
assert!(ev.url_fstring_or_concat);
}
#[test]
fn detects_concat_url() {
let src = "\
import requests\n\
def f(host):\n\
\x20 return requests.get('http://' + host + '/api')\n";
let ev = extract(src, "get");
assert!(ev.url_fstring_or_concat);
}
#[test]
fn does_not_fire_fstring_on_plain_string() {
let src = "\
import requests\n\
def f():\n\
\x20 return requests.get('https://example.com/api')\n";
let ev = extract(src, "get");
assert!(!ev.url_fstring_or_concat);
}
// --- Enclosing function / class ---
#[test]
fn detects_enclosing_function() {
let src = "\
import requests\n\
def proxy_handler(request):\n\
\x20 requests.get(request.body['url'])\n";
let ev = extract(src, "get");
assert_eq!(ev.enclosing_function, Some("proxy_handler".to_string()));
}
#[test]
fn detects_enclosing_class() {
let src = "\
import requests\n\
class FetchService:\n\
\x20 def fetch(self, url):\n\
\x20 requests.get(url)\n";
let ev = extract(src, "get");
assert_eq!(ev.enclosing_class, Some("FetchService".to_string()));
}
#[test]
fn no_enclosing_class_at_module_level() {
let src = "\
import requests\n\
requests.get('http://x')\n";
let ev = extract(src, "get");
assert_eq!(ev.enclosing_class, None);
}
// --- Inline `# repotoire: ...` annotations on the call line ---
#[test]
fn detects_ssrf_safe_annotation() {
let src = "\
import requests\n\
def f(url):\n\
\x20 return requests.get(url) # repotoire: ssrf-safe[validated]\n";
let ev = extract(src, "get");
assert_eq!(ev.ssrf_safe_annotation, Some("validated".to_string()));
assert_eq!(ev.ssrf_vulnerable_annotation, None);
}
#[test]
fn detects_ssrf_vulnerable_annotation() {
let src = "\
from advocate import Session\n\
def f(url):\n\
\x20 return Session().get(url) # repotoire: ssrf-vulnerable[audited]\n";
let ev = extract(src, "get");
assert_eq!(ev.ssrf_vulnerable_annotation, Some("audited".to_string()));
assert_eq!(ev.ssrf_safe_annotation, None);
}
#[test]
fn ignores_unrelated_annotation_kinds() {
let src = "\
import requests\n\
def f(url):\n\
\x20 return requests.get(url) # repotoire: command-static[ok]\n";
let ev = extract(src, "get");
assert_eq!(ev.ssrf_safe_annotation, None);
assert_eq!(ev.ssrf_vulnerable_annotation, None);
}
// --- Pure string helpers ---
#[test]
fn leftmost_identifier_handles_dotted_chains() {
assert_eq!(leftmost_identifier("requests.get"), "requests");
assert_eq!(leftmost_identifier("a.b.c"), "a");
assert_eq!(leftmost_identifier("urlopen"), "urlopen");
}
#[test]
fn is_http_module_matches_exact_and_submodules() {
assert!(is_http_module("advocate"));
assert!(is_http_module("advocate.Session"));
assert!(is_http_module("requests"));
assert!(is_http_module("requests.sessions"));
assert!(is_http_module("urllib"));
assert!(is_http_module("urllib.request"));
assert!(is_http_module("httpx"));
assert!(is_http_module("aiohttp"));
assert!(is_http_module("validators"));
assert!(is_http_module("defusedurl"));
assert!(!is_http_module("os"));
assert!(!is_http_module("subprocess"));
// Shares a textual prefix with `requests` but is a different module.
assert!(!is_http_module("requestor"));
}
#[test]
fn is_http_callee_matches_verbs_and_constructors() {
assert!(is_http_callee("requests.get"));
assert!(is_http_callee("requests.post"));
assert!(is_http_callee("session.put"));
assert!(is_http_callee("urlopen"));
assert!(is_http_callee("Session"));
assert!(is_http_callee("httpx.AsyncClient"));
assert!(!is_http_callee("urlparse"));
assert!(!is_http_callee("dict.update"));
}
// --- Worked examples combining several signals ---
#[test]
fn worked_example_canonical_realbug_extraction() {
let src = "\
import requests\n\
from flask import request\n\
def proxy_handler():\n\
\x20 url = request.json['url']\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(!ev.import_advocate);
assert_eq!(ev.api, Some(HttpApi::Requests));
assert!(ev.has_user_input_flow);
assert_eq!(ev.enclosing_function, Some("proxy_handler".to_string()));
}
#[test]
fn worked_example_canonical_advocate_safe_extraction() {
let src = "\
from advocate import Session\n\
def proxy_handler(request):\n\
\x20 url = request.body['url']\n\
\x20 s = Session()\n\
\x20 return s.get(url)\n";
let ev = extract(src, "get");
assert!(ev.import_advocate);
assert_eq!(ev.api, Some(HttpApi::Advocate));
assert_eq!(ev.enclosing_function, Some("proxy_handler".to_string()));
assert!(ev.has_user_input_flow);
}
#[test]
fn worked_example_canonical_advocate_direct_call() {
let src = "\
import advocate\n\
def f(url):\n\
\x20 return advocate.get(url)\n";
let ev = extract(src, "get");
assert_eq!(ev.api, Some(HttpApi::Advocate));
}
#[test]
fn url_concat_with_user_input_in_handler() {
let src = "\
import requests\n\
from flask import request\n\
def proxy_handler():\n\
\x20 host = request.json['host']\n\
\x20 return requests.get('http://' + host + '/api')\n";
let ev = extract(src, "get");
assert!(ev.has_user_input_flow);
assert!(ev.url_fstring_or_concat);
}
#[test]
fn fstring_with_user_input_in_handler() {
let src = "\
import requests\n\
from flask import request\n\
def proxy_handler():\n\
\x20 host = request.json['host']\n\
\x20 return requests.get(f'http://{host}/api')\n";
let ev = extract(src, "get");
assert!(ev.has_user_input_flow);
assert!(ev.url_fstring_or_concat);
}
// Pins current behaviour: `import_advocate` is set from the import alone, even though
// the call in this snippet goes through `requests`.
#[test]
fn unused_advocate_import_pins_v0_limitation() {
let src = "\
import advocate # not used\n\
import requests\n\
def f(url):\n\
\x20 return requests.get(url)\n";
let ev = extract(src, "get");
assert!(ev.import_advocate);
assert_eq!(ev.api, Some(HttpApi::Requests));
}
// The annotation is recorded next to the other signals; it does not suppress them.
#[test]
fn ssrf_safe_annotation_records_alongside_other_signals() {
let src = "\
import requests\n\
def proxy_handler(request):\n\
\x20 url = request.body['url']\n\
\x20 return requests.get(url) # repotoire: ssrf-safe[validated-by-cdn]\n";
let ev = extract(src, "get");
assert_eq!(
ev.ssrf_safe_annotation,
Some("validated-by-cdn".to_string())
);
assert!(ev.has_user_input_flow);
assert_eq!(ev.enclosing_function, Some("proxy_handler".to_string()));
assert_eq!(ev.api, Some(HttpApi::Requests));
}
}