use super::predict::{
classify_deserialize_callee, extract_deserialize_safe_reason,
extract_deserialize_vulnerable_source, line_contains_user_input,
matches_route_handler_decorator, matches_route_handler_name, matches_trust_boundary_name,
matches_upload_like_name, yaml_loader_is_safe, yaml_loader_is_unsafe, DeserializeApi, Evidence,
};
use crate::detectors::security::ast_helpers::{enclosing_python_function, node_text};
use std::collections::HashMap;
use tree_sitter::Node;
/// A Python `call` expression whose callee was recognized as a
/// deserialization API by `classify_deserialize_callee`.
pub(super) struct PythonDeserializeSite<'a> {
/// The tree-sitter `call` node of the invocation.
pub call_node: Node<'a>,
/// Classification of the callee; for `yaml.load` this is the result of
/// re-classifying by the loader passed to the call.
pub api: DeserializeApi,
/// Source text of the callee expression (the call's `function` field).
pub callee_label: String,
}
/// Walks the whole tree below `module_root` and returns every `call` node
/// whose callee is a recognized deserialization API.
///
/// The traversal uses an explicit LIFO stack with children pushed in source
/// order, so later siblings are visited before earlier ones; the returned
/// sites are therefore not in source order.
///
/// `eval(...)` calls are skipped here. An ambiguous `yaml.load` callee is
/// re-classified from the loader passed at the call site.
pub(super) fn collect_python_deserialize_sites<'a>(
    module_root: Node<'a>,
    source: &'a [u8],
) -> Vec<PythonDeserializeSite<'a>> {
    let mut found = Vec::new();
    let mut pending: Vec<Node<'_>> = vec![module_root];

    while let Some(current) = pending.pop() {
        // Always descend, whether or not `current` itself is a site: calls
        // can be nested inside the arguments of other calls.
        let mut walker = current.walk();
        pending.extend(current.children(&mut walker));

        if let Some(site) = deserialize_site_at(current, source) {
            found.push(site);
        }
    }

    found
}

/// Classifies a single node, returning a site only for `call` nodes with a
/// non-empty, non-`eval`, recognized callee.
fn deserialize_site_at<'a>(
    node: Node<'a>,
    source: &'a [u8],
) -> Option<PythonDeserializeSite<'a>> {
    if node.kind() != "call" {
        return None;
    }
    let callee = node.child_by_field_name("function")?;
    let label = node_text(callee, source).unwrap_or("");
    if label.is_empty() || is_eval_call(label) {
        return None;
    }

    let classified = classify_deserialize_callee(label);
    if !classified.is_recognized() {
        return None;
    }
    // `yaml.load` is only as safe as the loader it is given.
    let api = if matches!(classified, DeserializeApi::Ambiguous) && is_yaml_load(label) {
        reclassify_yaml_load(node, source)
    } else {
        classified
    };

    Some(PythonDeserializeSite {
        call_node: node,
        api,
        callee_label: label.to_string(),
    })
}
/// Returns `true` only when the callee text is exactly the bare name `eval`.
///
/// Qualified forms such as `builtins.eval` do not match.
fn is_eval_call(func_text: &str) -> bool {
    matches!(func_text, "eval")
}
/// Returns `true` when the callee text is `yaml.load`, either bare or as the
/// tail of a dotted path (e.g. `pkg.yaml.load`).
///
/// Names that merely end in `yaml.load` without a dot boundary (such as
/// `myyaml.load`) do not match.
fn is_yaml_load(func_text: &str) -> bool {
    match func_text.strip_suffix("yaml.load") {
        Some(prefix) => prefix.is_empty() || prefix.ends_with('.'),
        None => false,
    }
}
/// Re-classifies a `yaml.load(...)` call by the loader it is given.
///
/// The loader is taken from the `Loader=` keyword argument or, failing that,
/// from the second positional argument (`yaml.load(stream, yaml.SafeLoader)`),
/// since PyYAML accepts the loader in either position.
///
/// Returns:
/// * `Unsafe` when the call has no argument list or no loader can be found;
/// * `Safe` / `Unsafe` when `yaml_loader_is_safe` / `yaml_loader_is_unsafe`
///   recognizes the loader expression;
/// * `Ambiguous` when a loader is passed but is not recognized either way
///   (for example a variable).
fn reclassify_yaml_load<'a>(call_node: Node<'a>, source: &'a [u8]) -> DeserializeApi {
    let Some(args) = call_node.child_by_field_name("arguments") else {
        return DeserializeApi::Unsafe;
    };

    // The keyword form wins; a positional loader is only the fallback.
    let loader = collect_keyword_arguments(args, source)
        .remove("Loader")
        .or_else(|| second_positional_argument_text(args, source));
    let Some(loader) = loader else {
        return DeserializeApi::Unsafe;
    };

    if yaml_loader_is_safe(&loader) {
        return DeserializeApi::Safe;
    }
    if yaml_loader_is_unsafe(&loader) {
        return DeserializeApi::Unsafe;
    }
    DeserializeApi::Ambiguous
}

/// Returns the source text of the second positional argument of an
/// `argument_list`, when it can be identified statically.
///
/// Comments, keyword arguments and `**mapping` splats are not positional and
/// are skipped. If either of the first two positional slots is a `*iterable`
/// splat the real second parameter is unknown, so `None` is returned.
fn second_positional_argument_text<'a>(args: Node<'a>, source: &'a [u8]) -> Option<String> {
    // A lone generator expression can also occupy the `arguments` field; it
    // carries no separate loader argument.
    if args.kind() != "argument_list" {
        return None;
    }

    let mut cursor = args.walk();
    let mut positional = args.children(&mut cursor).filter(|child| {
        child.is_named()
            && !matches!(
                child.kind(),
                "comment" | "keyword_argument" | "dictionary_splat"
            )
    });
    let first = positional.next()?;
    let second = positional.next()?;
    if first.kind() == "list_splat" || second.kind() == "list_splat" {
        return None;
    }
    second.utf8_text(source).ok().map(str::to_string)
}
/// Builds the [`Evidence`] record for one deserialization call site.
///
/// Populated from the call's surroundings:
/// * name of the enclosing function and class, if any;
/// * upload-like / trust-boundary / route-handler flags derived from the
///   enclosing function's name;
/// * the route-handler flag, additionally, from decorator lines when the
///   name alone did not set it;
/// * whether a user-input marker appears within 10 lines of the call;
/// * whether the call's arguments look like they read a local file;
/// * `deserialize-safe` / `deserialize-vulnerable` annotations found on the
///   line where the call starts.
///
/// `_module_root` is accepted but not used. `file_path`, `api` and
/// `callee_label` are stored in the result as given.
pub(super) fn extract_python_evidence<'a>(
    call_node: Node<'a>,
    _module_root: Node<'a>,
    source: &'a [u8],
    lines: &[&str],
    file_path: Option<String>,
    api: DeserializeApi,
    callee_label: String,
) -> Evidence {
    /// Lines scanned on each side of the call for user-input markers.
    const USER_INPUT_RADIUS: usize = 10;

    let enclosing_fn = enclosing_python_function(call_node);
    let call_row = call_node.start_position().row;

    let mut evidence = Evidence {
        file_path,
        api: Some(api),
        callee_label: Some(callee_label),
        ..Default::default()
    };

    // Lexical context: enclosing function and class names.
    if let Some(name) = enclosing_fn
        .and_then(|function| function.child_by_field_name("name"))
        .and_then(|name_node| node_text(name_node, source))
    {
        evidence.enclosing_function = Some(name.to_string());
    }
    evidence.enclosing_class = enclosing_python_class_name(call_node, source);

    // Signals derived from the enclosing function's name.
    if let Some(name) = &evidence.enclosing_function {
        evidence.enclosing_upload_like = matches_upload_like_name(name);
        evidence.trust_boundary_name = matches_trust_boundary_name(name);
        if matches_route_handler_name(name) {
            evidence.enclosing_route_handler = true;
        }
    }

    // Decorators are only consulted when the name did not already decide it.
    if !evidence.enclosing_route_handler
        && enclosing_fn
            .filter(|function| function_has_route_decorator(*function, lines))
            .is_some()
    {
        evidence.enclosing_route_handler = true;
    }

    evidence.user_input_nearby = check_user_input_nearby(lines, call_row, USER_INPUT_RADIUS);
    evidence.local_file_source = call_has_local_file_source(call_node, source);

    // Annotations are read from the line on which the call starts.
    if let Some(call_line_text) = lines.get(call_row) {
        evidence.deserialize_safe_annotation = extract_deserialize_safe_reason(call_line_text);
        evidence.deserialize_vulnerable_annotation =
            extract_deserialize_vulnerable_source(call_line_text);
    }

    evidence
}
/// Collects the `name=value` keyword arguments that are direct children of
/// `args_node` into a map from argument name to the value's source text.
///
/// Entries whose name or value text cannot be read are dropped. If a name
/// occurs more than once, the last occurrence wins.
fn collect_keyword_arguments(args_node: Node<'_>, source: &[u8]) -> HashMap<String, String> {
    let mut keyword_values = HashMap::new();
    let mut walker = args_node.walk();

    keyword_values.extend(
        args_node
            .children(&mut walker)
            .filter(|child| child.is_named() && child.kind() == "keyword_argument")
            .filter_map(|keyword| {
                let key = keyword
                    .child_by_field_name("name")
                    .and_then(|name_node| node_text(name_node, source))?;
                let value = keyword
                    .child_by_field_name("value")
                    .and_then(|value_node| node_text(value_node, source))?;
                Some((key.to_string(), value.to_string()))
            }),
    );

    keyword_values
}
/// Heuristically reports whether the call's arguments read from a local file,
/// i.e. whether the argument text contains a call to something named `open`.
///
/// `open(` must start at an identifier boundary: bare `open(...)` and
/// attribute forms such as `io.open(...)` match, while names that merely end
/// in `open` (`urlopen(`, `popen(`) do not. Those read from the network or a
/// subprocess rather than a local file.
///
/// This is a textual check on the argument list only. A file opened elsewhere
/// (for instance in an enclosing `with` statement) is not detected.
fn call_has_local_file_source(call_node: Node<'_>, source: &[u8]) -> bool {
    let Some(args) = call_node.child_by_field_name("arguments") else {
        return false;
    };
    let Some(text) = node_text(args, source) else {
        return false;
    };
    text.match_indices("open(").any(|(idx, _)| {
        // Accept the match only if the preceding character cannot be part of
        // a longer identifier.
        text[..idx]
            .chars()
            .next_back()
            .map_or(true, |prev| !(prev.is_alphanumeric() || prev == '_'))
    })
}
fn function_has_route_decorator(fn_node: Node<'_>, lines: &[&str]) -> bool {
let mut parent = fn_node.parent();
while let Some(p) = parent {
if p.kind() == "decorated_definition" {
let mut cursor = p.walk();
for child in p.children(&mut cursor) {
if child.kind() == "decorator" {
let line_idx = child.start_position().row;
if let Some(line) = lines.get(line_idx) {
if matches_route_handler_decorator(line) {
return true;
}
}
}
}
break;
}
if p.kind() == "module" {
break;
}
parent = p.parent();
}
false
}
fn enclosing_python_class_name<'a>(node: Node<'a>, source: &'a [u8]) -> Option<String> {
let mut cur = node.parent()?;
loop {
if cur.kind() == "class_definition" {
let name = cur.child_by_field_name("name")?;
return node_text(name, source).map(str::to_string);
}
if cur.kind() == "module" {
return None;
}
cur = cur.parent()?;
}
}
/// Reports whether any line within `radius` lines of `call_line` (inclusive
/// on both sides) satisfies `line_contains_user_input`.
///
/// The window is clamped to the bounds of `lines`. A `call_line` that lies
/// beyond the end of `lines` yields an empty window and `false` rather than
/// a slice-index panic.
fn check_user_input_nearby(lines: &[&str], call_line: usize, radius: usize) -> bool {
    let total = lines.len();
    // Clamp both ends so `start <= end <= total` always holds; saturating
    // arithmetic keeps extreme row/radius values from overflowing.
    let start = call_line.saturating_sub(radius).min(total);
    let end = call_line
        .saturating_add(radius)
        .saturating_add(1)
        .min(total);
    lines[start..end]
        .iter()
        .any(|line| line_contains_user_input(line))
}
// Unit tests for Python deserialization site collection and evidence
// extraction. Each test parses a small Python snippet with tree-sitter and
// checks either the classification of the collected sites or individual
// fields of the extracted `Evidence`.
#[cfg(test)]
mod tests {
use super::*;
use crate::detectors::ast_fingerprint::parse_root_ext;
use crate::parsers::lightweight::Language;
/// Returns the first `call` node (depth-first, source order) whose callee
/// text ends in `attr_name` after the last `.`. Panics if there is none.
fn first_call_with_attr<'tree>(
tree: &'tree tree_sitter::Tree,
source: &[u8],
attr_name: &str,
) -> tree_sitter::Node<'tree> {
// Recursive pre-order search; a matching node is returned before its
// children are visited.
fn walk<'a>(
node: tree_sitter::Node<'a>,
source: &[u8],
attr_name: &str,
) -> Option<tree_sitter::Node<'a>> {
if node.kind() == "call" {
if let Some(func) = node.child_by_field_name("function") {
let text = node_text(func, source).unwrap_or("");
// Compare only the final dotted segment of the callee.
let last = text.rsplit('.').next().unwrap_or("");
if last == attr_name {
return Some(node);
}
}
}
let mut cursor = node.walk();
for child in node.children(&mut cursor) {
if let Some(found) = walk(child, source, attr_name) {
return Some(found);
}
}
None
}
walk(tree.root_node(), source, attr_name)
.unwrap_or_else(|| panic!("no call to {} found in source", attr_name))
}
/// Parses `src`, locates the first call whose callee ends in `attr`,
/// classifies it the same way `collect_python_deserialize_sites` does
/// (including `yaml.load` re-classification) and returns its evidence.
fn extract(src: &str, attr: &str) -> Evidence {
let tree = parse_root_ext(src, Language::Python, "py").expect("parse python");
let root = tree.root_node();
let call = first_call_with_attr(&tree, src.as_bytes(), attr);
let lines: Vec<&str> = src.lines().collect();
let func_text = call
.child_by_field_name("function")
.and_then(|f| node_text(f, src.as_bytes()))
.unwrap_or("")
.to_string();
let mut api = classify_deserialize_callee(&func_text);
if matches!(api, DeserializeApi::Ambiguous) && is_yaml_load(&func_text) {
api = reclassify_yaml_load(call, src.as_bytes());
}
extract_python_evidence(call, root, src.as_bytes(), &lines, None, api, func_text)
}
/// Parses `src` and returns the classification of every collected site.
fn collect_sites(src: &str) -> Vec<DeserializeApi> {
let tree = parse_root_ext(src, Language::Python, "py").expect("parse python");
let root = tree.root_node();
collect_python_deserialize_sites(root, src.as_bytes())
.into_iter()
.map(|s| s.api)
.collect()
}
// --- Site collection: classification by callee ---
#[test]
fn collect_picks_up_pickle_loads() {
let sites = collect_sites("import pickle\npickle.loads(blob)\n");
assert_eq!(sites, vec![DeserializeApi::Unsafe]);
}
#[test]
fn collect_picks_up_yaml_safe_load() {
let sites = collect_sites("import yaml\nyaml.safe_load(blob)\n");
assert_eq!(sites, vec![DeserializeApi::Safe]);
}
#[test]
fn collect_picks_up_json_loads() {
let sites = collect_sites("import json\njson.loads(blob)\n");
assert_eq!(sites, vec![DeserializeApi::Safe]);
}
#[test]
fn collect_picks_up_marshal_loads_as_unsafe() {
let sites = collect_sites("import marshal\nmarshal.loads(blob)\n");
assert_eq!(sites, vec![DeserializeApi::Unsafe]);
}
#[test]
fn collect_picks_up_dill_loads_as_unsafe() {
let sites = collect_sites("import dill\ndill.loads(blob)\n");
assert_eq!(sites, vec![DeserializeApi::Unsafe]);
}
#[test]
fn collect_skips_eval_calls() {
let sites = collect_sites("eval(blob)\n");
assert!(sites.is_empty(), "eval() must be skipped, got {:?}", sites);
}
#[test]
fn collect_skips_unknown_callees() {
let sites = collect_sites("import json\njson.dumps({'a': 1})\nfoo.bar('x')\n");
assert!(sites.is_empty());
}
// --- Site collection: yaml.load re-classification by Loader keyword ---
#[test]
fn collect_yaml_load_without_loader_is_unsafe() {
let sites = collect_sites("import yaml\nyaml.load(blob)\n");
assert_eq!(sites, vec![DeserializeApi::Unsafe]);
}
#[test]
fn collect_yaml_load_with_safeloader_is_safe() {
let sites = collect_sites("import yaml\nyaml.load(blob, Loader=yaml.SafeLoader)\n");
assert_eq!(sites, vec![DeserializeApi::Safe]);
}
#[test]
fn collect_yaml_load_with_csafeloader_is_safe() {
let sites = collect_sites("import yaml\nyaml.load(blob, Loader=yaml.CSafeLoader)\n");
assert_eq!(sites, vec![DeserializeApi::Safe]);
}
#[test]
fn collect_yaml_load_with_full_loader_is_unsafe() {
let sites = collect_sites("import yaml\nyaml.load(blob, Loader=yaml.FullLoader)\n");
assert_eq!(sites, vec![DeserializeApi::Unsafe]);
}
#[test]
fn collect_yaml_load_with_default_loader_is_unsafe() {
let sites = collect_sites("import yaml\nyaml.load(blob, Loader=yaml.Loader)\n");
assert_eq!(sites, vec![DeserializeApi::Unsafe]);
}
// A loader passed through a variable cannot be resolved statically.
#[test]
fn collect_yaml_load_with_dynamic_loader_is_ambiguous() {
let sites = collect_sites(
"import yaml\n\
def f(blob, dynamic_loader):\n\
\x20 return yaml.load(blob, Loader=dynamic_loader)\n",
);
assert_eq!(sites, vec![DeserializeApi::Ambiguous]);
}
// --- Evidence: enclosing function / class ---
#[test]
fn detects_enclosing_function_name() {
let src = "\
import pickle\n\
def load_cache(path):\n\
\x20 return pickle.loads(open(path, 'rb').read())\n";
let ev = extract(src, "loads");
assert_eq!(ev.enclosing_function.as_deref(), Some("load_cache"));
}
#[test]
fn detects_enclosing_class_name() {
let src = "\
import pickle\n\
class Cache:\n\
\x20 def load(self, path):\n\
\x20 return pickle.loads(open(path, 'rb').read())\n";
let ev = extract(src, "loads");
assert_eq!(ev.enclosing_class.as_deref(), Some("Cache"));
}
// --- Evidence: signals derived from the enclosing function's name ---
#[test]
fn detects_upload_like_function_name() {
let src = "\
import yaml\n\
def upload_handler(blob):\n\
\x20 return yaml.load(blob, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.enclosing_upload_like);
assert!(ev.enclosing_route_handler);
}
#[test]
fn detects_load_function_name_as_upload_like() {
let src = "\
import yaml\n\
def load_payload(blob):\n\
\x20 return yaml.load(blob, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.enclosing_upload_like);
}
#[test]
fn detects_trust_boundary_name() {
let src = "\
import yaml\n\
def parse_trusted_config(blob):\n\
\x20 return yaml.load(blob, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.trust_boundary_name);
}
#[test]
fn detects_admin_in_name() {
let src = "\
import yaml\n\
def load_admin_settings(blob):\n\
\x20 return yaml.load(blob, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.trust_boundary_name);
}
// --- Evidence: route handler detection via decorators ---
#[test]
fn detects_flask_route_decorator() {
let src = "\
from flask import request\n\
import yaml\n\
@app.route('/config', methods=['POST'])\n\
def update(blob):\n\
\x20 return yaml.load(blob, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.enclosing_route_handler);
}
#[test]
fn detects_fastapi_router_decorator() {
let src = "\
import yaml\n\
@router.post('/config')\n\
def update(blob):\n\
\x20 return yaml.load(blob, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.enclosing_route_handler);
}
#[test]
fn no_route_handler_for_plain_function() {
let src = "\
import yaml\n\
def compute(blob):\n\
\x20 return yaml.load(blob, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(!ev.enclosing_route_handler);
}
// --- Evidence: user input near the call ---
#[test]
fn detects_request_data_within_radius() {
let src = "\
import yaml\n\
def f(req):\n\
\x20 data = req.request.data\n\
\x20 return yaml.load(data, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.user_input_nearby);
}
#[test]
fn detects_request_json() {
let src = "\
import yaml\n\
def f():\n\
\x20 body = request.json\n\
\x20 return yaml.load(body, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(ev.user_input_nearby);
}
#[test]
fn no_user_input_for_local_var() {
let src = "\
import yaml\n\
def f():\n\
\x20 data = compute()\n\
\x20 return yaml.load(data, Loader=dyn)\n";
let ev = extract(src, "load");
assert!(!ev.user_input_nearby);
}
// --- Evidence: local file source in the call's arguments ---
#[test]
fn detects_open_in_pickle_call() {
let src = "\
import pickle\n\
def load_cache(path):\n\
\x20 return pickle.loads(open(path, 'rb').read())\n";
let ev = extract(src, "loads");
assert!(ev.local_file_source);
}
#[test]
fn no_local_file_for_plain_arg() {
let src = "\
import pickle\n\
def load_cache(blob):\n\
\x20 return pickle.loads(blob)\n";
let ev = extract(src, "loads");
assert!(!ev.local_file_source);
}
// --- Evidence: inline annotations on the call line ---
#[test]
fn extracts_deserialize_safe_annotation() {
let src = "\
import pickle\n\
def f(blob):\n\
\x20 return pickle.loads(blob) # repotoire: deserialize-safe[hmac-verified]\n";
let ev = extract(src, "loads");
assert_eq!(
ev.deserialize_safe_annotation.as_deref(),
Some("hmac-verified")
);
}
#[test]
fn extracts_deserialize_vulnerable_annotation() {
let src = "\
import json\n\
def f(blob):\n\
\x20 return json.loads(blob) # repotoire: deserialize-vulnerable[third-party]\n";
let ev = extract(src, "loads");
assert_eq!(
ev.deserialize_vulnerable_annotation.as_deref(),
Some("third-party")
);
}
// An annotation for a different detector must not be picked up.
#[test]
fn does_not_extract_unrelated_annotation() {
let src = "\
import pickle\n\
def f(blob):\n\
\x20 return pickle.loads(blob) # repotoire: jwt-safe[verified]\n";
let ev = extract(src, "loads");
assert!(ev.deserialize_safe_annotation.is_none());
assert!(ev.deserialize_vulnerable_annotation.is_none());
}
// --- Evidence: api and callee label are recorded as passed in ---
#[test]
fn evidence_records_safe_api_for_yaml_safe_load() {
let src = "import yaml\nyaml.safe_load(blob)\n";
let ev = extract(src, "safe_load");
assert_eq!(ev.api, Some(DeserializeApi::Safe));
}
#[test]
fn evidence_records_unsafe_api_for_pickle_loads() {
let src = "import pickle\npickle.loads(blob)\n";
let ev = extract(src, "loads");
assert_eq!(ev.api, Some(DeserializeApi::Unsafe));
}
#[test]
fn evidence_records_callee_label() {
let src = "import pickle\npickle.loads(blob)\n";
let ev = extract(src, "loads");
assert_eq!(ev.callee_label.as_deref(), Some("pickle.loads"));
}
// --- End-to-end scenarios combining several evidence fields ---
#[test]
fn case_a_full_evidence_yaml_safe_load_in_handler() {
let src = "\
from flask import request\n\
import yaml\n\
@app.route('/config', methods=['POST'])\n\
def update_config():\n\
\x20 return yaml.safe_load(request.data)\n";
let ev = extract(src, "safe_load");
assert_eq!(ev.api, Some(DeserializeApi::Safe));
assert!(ev.user_input_nearby);
assert!(ev.enclosing_route_handler);
assert_eq!(ev.enclosing_function.as_deref(), Some("update_config"));
}
// `open(...)` sits in the `with` header here, not in the call's argument
// text, so the argument-only local-file heuristic does not fire.
#[test]
fn case_b_full_evidence_pickle_loads_on_local_file() {
let src = "\
import pickle\n\
def load_cache():\n\
\x20 with open('/var/cache/sessions.pkl', 'rb') as f:\n\
\x20 return pickle.loads(f.read())\n";
let ev = extract(src, "loads");
assert_eq!(ev.api, Some(DeserializeApi::Unsafe));
assert!(!ev.local_file_source);
assert_eq!(ev.enclosing_function.as_deref(), Some("load_cache"));
}
#[test]
fn case_b_alt_full_evidence_pickle_loads_direct_open() {
let src = "\
import pickle\n\
def load_cache(path):\n\
\x20 return pickle.loads(open(path, 'rb').read())\n";
let ev = extract(src, "loads");
assert_eq!(ev.api, Some(DeserializeApi::Unsafe));
assert!(ev.local_file_source);
}
#[test]
fn case_d_full_evidence_yaml_load_with_safeloader_reclassifies() {
let src = "\
import yaml\n\
def f(blob):\n\
\x20 return yaml.load(blob, Loader=yaml.SafeLoader)\n";
let ev = extract(src, "load");
assert_eq!(ev.api, Some(DeserializeApi::Safe));
}
#[test]
fn case_d_alt_full_evidence_yaml_load_without_safeloader_is_unsafe() {
let src = "\
import yaml\n\
def f(blob):\n\
\x20 return yaml.load(blob)\n";
let ev = extract(src, "load");
assert_eq!(ev.api, Some(DeserializeApi::Unsafe));
}
}