use super::predict::{
classify_user_input_source, extract_nosql_safe_reason, extract_nosql_vulnerable_source,
line_contains_dangerous_operator, line_contains_developer_operator, line_contains_type_cast,
matches_route_handler_decorator, matches_route_handler_name, matches_trust_boundary_name,
Evidence, NosqlApi, UserInputSource, UNSTRUCTURED_JSON_USER_INPUT_SUBSTRINGS,
};
use crate::detectors::security::ast_helpers::{enclosing_python_function, node_text};
use tree_sitter::Node;
/// Method names that, when called on a dotted receiver (`users.find_one(...)`,
/// `db.users.aggregate(...)`), mark the call as a candidate PyMongo query site in
/// `collect_python_nosql_sites`.
///
/// NOTE(review): `find`, `count` and `update` are also methods of Python builtins
/// (`str.find`, `list.count`, `dict.update`). `receiver_looks_like_array` only rejects a fixed
/// set of receiver names, so a call such as `config.update({...})` is still collected.
pub(super) const PYMONGO_QUERY_METHODS: &[&str] = &[
    // Lookups and find-and-modify variants.
    "find",
    "find_one",
    "find_by_id",
    "find_one_and_update",
    "find_one_and_delete",
    "find_one_and_replace",
    // Updates / replacement.
    "update",
    "update_one",
    "update_many",
    "replace_one",
    // Deletes.
    "delete",
    "delete_one",
    "delete_many",
    // Aggregation, counting and distinct values.
    "aggregate",
    "count",
    "count_documents",
    "estimated_document_count",
    "distinct",
];
/// A candidate PyMongo query call found by `collect_python_nosql_sites`.
pub(super) struct PythonNosqlSite<'a> {
    /// The tree-sitter `call` node of the query.
    pub call_node: Node<'a>,
    /// Query-shape classification of the call's first positional argument
    /// (computed by `classify_query_shape`).
    pub api: NosqlApi,
    /// Source text of the callee expression, e.g. `db.users.find_one`.
    pub callee_label: String,
}
/// Finds every call under `module_root` that looks like a PyMongo query and classifies the
/// shape of its query argument.
///
/// A call qualifies when its callee text is a non-empty dotted `receiver.method` whose method
/// is listed in `PYMONGO_QUERY_METHODS` and whose last receiver segment is not a known
/// array-like name (see `receiver_looks_like_array`).
///
/// Traversal uses an explicit stack: children are pushed left to right and popped right to
/// left. Results therefore come back in that visit order, not strict source order.
pub(super) fn collect_python_nosql_sites<'a>(
    module_root: Node<'a>,
    source: &'a [u8],
) -> Vec<PythonNosqlSite<'a>> {
    let mut found: Vec<PythonNosqlSite<'a>> = Vec::new();
    let mut pending: Vec<Node<'a>> = vec![module_root];

    while let Some(current) = pending.pop() {
        // Queue children before inspecting this node so calls nested inside arguments or
        // function bodies are visited as well.
        let mut walker = current.walk();
        pending.extend(current.children(&mut walker));

        if current.kind() != "call" {
            continue;
        }

        let callee = current
            .child_by_field_name("function")
            .and_then(|func| node_text(func, source))
            .unwrap_or("");
        let is_candidate = !callee.is_empty()
            && looks_like_pymongo_call(callee)
            && !receiver_looks_like_array(callee);

        if is_candidate {
            found.push(PythonNosqlSite {
                call_node: current,
                api: classify_query_shape(current, source),
                callee_label: callee.to_string(),
            });
        }
    }

    found
}
/// Returns `true` when `func_text` is a dotted callee (`receiver.method`) whose final segment
/// is one of `PYMONGO_QUERY_METHODS`.
///
/// Undotted callees such as a bare `find(...)` are rejected, because a query method is always
/// reached through a receiver.
fn looks_like_pymongo_call(func_text: &str) -> bool {
    match func_text.rsplit_once('.') {
        Some((_receiver, method)) => PYMONGO_QUERY_METHODS.iter().any(|known| *known == method),
        None => false,
    }
}
/// Returns `true` when the receiver right before the method in `func_text` is named like an
/// in-memory sequence or mapping (`items.find`, `self.results.count`) rather than a database
/// collection.
///
/// Only the last dotted segment of the receiver is compared, case-insensitively and by exact
/// match against a fixed list. Undotted callees return `false`.
fn receiver_looks_like_array(func_text: &str) -> bool {
    const ARRAY_RECEIVER_NAMES: &[&str] = &[
        "items", "list", "array", "results", "options", "elements", "entries", "records", "rows",
        "values", "keys",
    ];

    let Some((receiver, _method)) = func_text.rsplit_once('.') else {
        return false;
    };
    let receiver = receiver.to_lowercase();
    let tail = match receiver.rsplit_once('.') {
        Some((_, last)) => last,
        None => receiver.as_str(),
    };
    ARRAY_RECEIVER_NAMES.contains(&tail)
}
fn classify_query_shape(call_node: Node<'_>, source: &[u8]) -> NosqlApi {
let Some(args) = call_node.child_by_field_name("arguments") else {
return NosqlApi::Ambiguous;
};
let Some(first_arg) = first_positional_arg(args) else {
return NosqlApi::Ambiguous;
};
if dict_has_user_input_splat(first_arg, source) {
return NosqlApi::DictExpansion;
}
let dict_text = node_text(first_arg, source).unwrap_or("");
if has_dangerous_operator_with_user_input(first_arg, source, dict_text) {
return NosqlApi::OperatorInjection;
}
if first_arg.kind() == "dictionary"
&& !line_contains_dangerous_operator(dict_text)
&& is_typed_value_query(first_arg, source)
{
return NosqlApi::TypedValueQuery;
}
NosqlApi::Ambiguous
}
/// Returns the first positional argument of a call's `argument_list` node, or `None` when the
/// call has none.
///
/// Skips:
/// - anonymous children (parentheses, commas);
/// - `keyword_argument`s;
/// - `comment` nodes. Tree-sitter attaches a comment written inside the parentheses
///   (`find(  # note` followed by the filter on the next line) as a named child of the
///   argument list. Returning it would hide the real filter from every shape check.
fn first_positional_arg<'a>(args: Node<'a>) -> Option<Node<'a>> {
    let mut cursor = args.walk();
    for child in args.children(&mut cursor) {
        if !child.is_named() || matches!(child.kind(), "keyword_argument" | "comment") {
            continue;
        }
        return Some(child);
    }
    None
}
/// Returns `true` when `node` is a dictionary literal containing a `**expr` entry whose
/// expression looks like user input according to `is_user_input_identifier`.
///
/// The expression is the splat's text with leading `*`s and surrounding whitespace removed.
/// Non-dictionary nodes always return `false`.
fn dict_has_user_input_splat(node: Node<'_>, source: &[u8]) -> bool {
    if node.kind() != "dictionary" {
        return false;
    }

    let mut cursor = node.walk();
    let splats_user_input = node
        .children(&mut cursor)
        .filter(|entry| entry.kind() == "dictionary_splat")
        .any(|splat| {
            let spread = node_text(splat, source)
                .unwrap_or("")
                .trim_start_matches('*')
                .trim();
            is_user_input_identifier(spread)
        });
    splats_user_input
}
/// Returns `true` when the lower-cased `expr` contains any of the
/// `UNSTRUCTURED_JSON_USER_INPUT_SUBSTRINGS` markers.
fn is_user_input_identifier(expr: &str) -> bool {
    let folded = expr.to_lowercase();
    for needle in UNSTRUCTURED_JSON_USER_INPUT_SUBSTRINGS {
        if folded.contains(needle) {
            return true;
        }
    }
    false
}
fn has_dangerous_operator_with_user_input(_arg: Node<'_>, _source: &[u8], dict_text: &str) -> bool {
if !line_contains_dangerous_operator(dict_text) {
return false;
}
let lower = dict_text.to_lowercase();
for s in UNSTRUCTURED_JSON_USER_INPUT_SUBSTRINGS {
if lower.contains(s) {
return true;
}
}
for s in super::predict::TYPED_STRING_USER_INPUT_SUBSTRINGS {
if lower.contains(s) {
return true;
}
}
false
}
/// Returns `true` when every value in the dictionary literal `dict_node` is typed.
///
/// A value is typed when it is free of user-input markers or wrapped in a recognised type
/// cast (`value_contains_type_cast`). Nested dictionary literals are checked recursively.
///
/// A `**expr` entry makes the result `false`. Its keys and values are not visible in the
/// literal, so nothing can be said about their types. `classify_query_shape` already reports
/// top-level splats of user-input-looking expressions as `DictExpansion`; this covers splats of
/// opaque locals such as `{**filters}`.
fn is_typed_value_query(dict_node: Node<'_>, source: &[u8]) -> bool {
    let mut cursor = dict_node.walk();
    for child in dict_node.children(&mut cursor) {
        if child.kind() == "dictionary_splat" {
            return false;
        }
        // Anything else that is not a `key: value` pair (punctuation, comments) holds no value.
        if child.kind() != "pair" {
            continue;
        }
        let Some(value) = child.child_by_field_name("value") else {
            continue;
        };
        let value_text = node_text(value, source).unwrap_or("");
        // Raw user input is only acceptable when wrapped in a cast.
        if value_contains_user_input(value_text) && !value_contains_type_cast(value_text) {
            return false;
        }
        // Sub-documents such as `{"$ne": ...}` must themselves be typed.
        if value.kind() == "dictionary" && !is_typed_value_query(value, source) {
            return false;
        }
    }
    true
}
/// Returns `true` when the lower-cased `text` contains any marker from either
/// `UNSTRUCTURED_JSON_USER_INPUT_SUBSTRINGS` or `TYPED_STRING_USER_INPUT_SUBSTRINGS`.
fn value_contains_user_input(text: &str) -> bool {
    let folded = text.to_lowercase();
    UNSTRUCTURED_JSON_USER_INPUT_SUBSTRINGS
        .iter()
        .chain(super::predict::TYPED_STRING_USER_INPUT_SUBSTRINGS.iter())
        .any(|needle| folded.contains(needle))
}
/// Returns `true` when `text` contains a call that coerces or validates its argument into a
/// fixed type: `str(`/`int(`/`float(`/`bool(`, `ObjectId(`, `UUID(`, `.parse_obj(`,
/// `.model_validate(` or `schema.load(`.
///
/// The bare builtin casts count only as standalone names. `int(` inside `print(`, `Point(` or
/// `fingerprint(` is not a cast and must not make raw user input look typed. A preceding `.` is
/// still accepted, so qualified forms such as `np.int(` keep matching.
fn value_contains_type_cast(text: &str) -> bool {
    // Builtin casts: must not be the tail of a longer identifier.
    const BUILTIN_CASTS: &[&str] = &["str(", "int(", "float(", "bool("];
    // Matched anywhere. `ObjectId(` / `UUID(` also cover `bson.ObjectId(` / `uuid.UUID(`.
    const QUALIFIED_CASTS: &[&str] = &[
        "ObjectId(",
        "UUID(",
        ".parse_obj(",
        ".model_validate(",
        "schema.load(",
    ];
    QUALIFIED_CASTS.iter().any(|pat| text.contains(pat))
        || BUILTIN_CASTS
            .iter()
            .any(|pat| contains_standalone_call(text, pat))
}

/// Returns `true` when `needle` occurs in `text` at a position not immediately preceded by an
/// identifier character (alphanumeric or `_`).
fn contains_standalone_call(text: &str, needle: &str) -> bool {
    text.match_indices(needle).any(|(start, _)| {
        !matches!(
            text[..start].chars().next_back(),
            Some(prev) if prev.is_alphanumeric() || prev == '_'
        )
    })
}
/// Builds the `Evidence` record for one PyMongo call site.
///
/// Gathers:
/// - the enclosing function and class names;
/// - whether the function name matches a trust boundary;
/// - whether the function is a route handler, by name or else by decorator
///   (`function_has_route_decorator`);
/// - the user-input source seen within 10 lines of the call;
/// - whether a type cast appears within 5 lines of the call;
/// - whether any line spanned by the call contains `$regex` together with user input;
/// - whether the first positional argument uses a developer-written operator without user
///   input;
/// - the first `nosql-safe` / `nosql-vulnerable` annotation found on the lines spanned by the
///   call.
///
/// `_module_root` is currently unused.
pub(super) fn extract_python_evidence<'a>(
    call_node: Node<'a>,
    _module_root: Node<'a>,
    source: &'a [u8],
    lines: &[&str],
    file_path: Option<String>,
    api: NosqlApi,
    callee_label: String,
) -> Evidence {
    let mut ev = Evidence {
        file_path,
        api: Some(api),
        callee_label: Some(callee_label),
        ..Default::default()
    };

    let fn_node = enclosing_python_function(call_node);
    if let Some(fn_node) = fn_node {
        if let Some(name_node) = fn_node.child_by_field_name("name") {
            if let Some(name) = node_text(name_node, source) {
                ev.enclosing_function = Some(name.to_string());
            }
        }
    }
    ev.enclosing_class = enclosing_python_class_name(call_node, source);

    if let Some(fn_name) = &ev.enclosing_function {
        ev.trust_boundary_name = matches_trust_boundary_name(fn_name);
        if matches_route_handler_name(fn_name) {
            ev.enclosing_route_handler = true;
        }
    }
    // Fall back to decorators only when the name alone did not identify a handler.
    if !ev.enclosing_route_handler {
        if let Some(fn_node) = fn_node {
            if function_has_route_decorator(fn_node, lines) {
                ev.enclosing_route_handler = true;
            }
        }
    }

    let call_line = call_node.start_position().row;
    ev.user_input_source = classify_nearby_user_input(lines, call_line, 10);
    ev.type_cast_nearby = type_cast_nearby(lines, call_line, 5);

    // Every source line the call spans. A query split across lines commonly has its `$regex`
    // pair on an inner line and its annotation after the closing `)`, and looking only at the
    // first line would miss both. For a single-line call this is exactly that one line.
    let last_line = call_node.end_position().row.max(call_line);
    let span_end = last_line.saturating_add(1).min(lines.len());
    let call_span: &[&str] = lines.get(call_line..span_end).unwrap_or(&[]);

    // `$regex` and the user input must still appear on the same line.
    ev.has_dollar_regex_with_user_input = call_span
        .iter()
        .any(|line| line.contains("$regex") && line_user_input_present(line));
    ev.has_developer_written_operator = has_developer_written_operator(call_node, source);
    // Scanning starts at the call's first line, so an annotation there wins.
    ev.nosql_safe_annotation = call_span
        .iter()
        .find_map(|line| extract_nosql_safe_reason(line));
    ev.nosql_vulnerable_annotation = call_span
        .iter()
        .find_map(|line| extract_nosql_vulnerable_source(line));
    ev
}
/// Returns `true` when `classify_user_input_source` finds any user-input source on `line`.
fn line_user_input_present(line: &str) -> bool {
    classify_user_input_source(line) != UserInputSource::None
}
/// Classifies the user input visible around a call.
///
/// The window covers `call_line - radius ..= call_line + radius`, clamped to `lines`. The
/// result is `UnstructuredJson` if any line in the window classifies that way, otherwise
/// `TypedString` if any line does, otherwise `None`.
///
/// An empty window (e.g. `call_line` past the end of `lines`) yields `None` instead of
/// panicking.
fn classify_nearby_user_input(lines: &[&str], call_line: usize, radius: usize) -> UserInputSource {
    let start = call_line.saturating_sub(radius);
    let end = call_line
        .saturating_add(radius)
        .saturating_add(1)
        .min(lines.len());
    // `get` instead of `&lines[start..end]`: when `start > end` the slice would panic.
    let window = lines.get(start..end).unwrap_or(&[]);

    // Each line is classified once. `UnstructuredJson` short-circuits; `TypedString` is only
    // reported when no line is unstructured.
    let mut saw_typed_string = false;
    for line in window {
        match classify_user_input_source(line) {
            UserInputSource::UnstructuredJson => return UserInputSource::UnstructuredJson,
            UserInputSource::TypedString => saw_typed_string = true,
            _ => {}
        }
    }
    if saw_typed_string {
        UserInputSource::TypedString
    } else {
        UserInputSource::None
    }
}
/// Returns `true` when any line within `radius` lines of `call_line` (inclusive, clamped to
/// `lines`) contains a type cast according to `line_contains_type_cast`.
///
/// An empty window (e.g. `call_line` past the end of `lines`) yields `false` instead of
/// panicking.
fn type_cast_nearby(lines: &[&str], call_line: usize, radius: usize) -> bool {
    let start = call_line.saturating_sub(radius);
    let end = call_line
        .saturating_add(radius)
        .saturating_add(1)
        .min(lines.len());
    // `get` instead of `&lines[start..end]`: when `start > end` the slice would panic.
    lines
        .get(start..end)
        .unwrap_or(&[])
        .iter()
        .any(|line| line_contains_type_cast(line))
}
fn has_developer_written_operator(call_node: Node<'_>, source: &[u8]) -> bool {
let Some(args) = call_node.child_by_field_name("arguments") else {
return false;
};
let Some(first_arg) = first_positional_arg(args) else {
return false;
};
let dict_text = node_text(first_arg, source).unwrap_or("");
if !line_contains_developer_operator(dict_text) {
return false;
}
!value_contains_user_input(dict_text)
}
/// Returns `true` when `fn_node`, or any definition enclosing it below the module, carries a
/// decorator whose first source line satisfies `matches_route_handler_decorator`.
///
/// Climbing past the function's own decorators has two effects:
/// - a method inherits a route decorator placed on its class;
/// - a nested function inherits its outer handler's decorator.
///
/// Every `decorated_definition` on the way up is checked. An unrelated decorator on the method
/// itself (e.g. `@login_required`) therefore no longer stops the search before an enclosing
/// route decorator is reached.
fn function_has_route_decorator(fn_node: Node<'_>, lines: &[&str]) -> bool {
    let mut ancestor = fn_node.parent();
    while let Some(node) = ancestor {
        match node.kind() {
            "module" => break,
            "decorated_definition" => {
                let mut cursor = node.walk();
                let routed = node
                    .children(&mut cursor)
                    .filter(|child| child.kind() == "decorator")
                    .filter_map(|decorator| lines.get(decorator.start_position().row))
                    .any(|line| matches_route_handler_decorator(line));
                if routed {
                    return true;
                }
            }
            _ => {}
        }
        ancestor = node.parent();
    }
    false
}
/// Returns the name of the nearest `class_definition` enclosing `node`.
///
/// Returns `None` when the `module` is reached first, when the tree runs out of parents, or
/// when the class has no readable name.
fn enclosing_python_class_name<'a>(node: Node<'a>, source: &'a [u8]) -> Option<String> {
    let mut ancestor = node.parent();
    while let Some(current) = ancestor {
        match current.kind() {
            "class_definition" => {
                return current
                    .child_by_field_name("name")
                    .and_then(|name| node_text(name, source))
                    .map(str::to_string);
            }
            "module" => return None,
            _ => ancestor = current.parent(),
        }
    }
    None
}
// Unit tests for PyMongo site collection, query-shape classification and evidence extraction.
// Each test parses a small Python snippet with the project's tree-sitter front end.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::detectors::ast_fingerprint::parse_root_ext;
    use crate::parsers::lightweight::Language;

    // Depth-first search for the first `call` node whose callee's final dotted segment equals
    // `attr_name`. Panics if the tree contains no such call.
    fn first_call_with_attr<'tree>(
        tree: &'tree tree_sitter::Tree,
        source: &[u8],
        attr_name: &str,
    ) -> tree_sitter::Node<'tree> {
        fn walk<'a>(
            node: tree_sitter::Node<'a>,
            source: &[u8],
            attr_name: &str,
        ) -> Option<tree_sitter::Node<'a>> {
            if node.kind() == "call" {
                if let Some(func) = node.child_by_field_name("function") {
                    let text = node_text(func, source).unwrap_or("");
                    let last = text.rsplit('.').next().unwrap_or("");
                    if last == attr_name {
                        return Some(node);
                    }
                }
            }
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                if let Some(found) = walk(child, source, attr_name) {
                    return Some(found);
                }
            }
            None
        }
        walk(tree.root_node(), source, attr_name)
            .unwrap_or_else(|| panic!("no call to {} found in source", attr_name))
    }

    // Parses `src` as Python, finds the first call to `.attr`, classifies its query shape and
    // returns the full `Evidence` for that call (with no file path).
    fn extract(src: &str, attr: &str) -> Evidence {
        let tree = parse_root_ext(src, Language::Python, "py").expect("parse python");
        let root = tree.root_node();
        let call = first_call_with_attr(&tree, src.as_bytes(), attr);
        let lines: Vec<&str> = src.lines().collect();
        let func_text = call
            .child_by_field_name("function")
            .and_then(|f| node_text(f, src.as_bytes()))
            .unwrap_or("")
            .to_string();
        let api = classify_query_shape(call, src.as_bytes());
        extract_python_evidence(call, root, src.as_bytes(), &lines, None, api, func_text)
    }

    // Parses `src` as Python and returns `(api, callee_label)` for every collected site.
    fn collect_sites(src: &str) -> Vec<(NosqlApi, String)> {
        let tree = parse_root_ext(src, Language::Python, "py").expect("parse python");
        let root = tree.root_node();
        collect_python_nosql_sites(root, src.as_bytes())
            .into_iter()
            .map(|s| (s.api, s.callee_label))
            .collect()
    }

    // --- Site collection -------------------------------------------------------------------

    #[test]
    fn collect_picks_up_find_one() {
        let sites = collect_sites("users.find_one({})\n");
        assert_eq!(sites.len(), 1);
        assert_eq!(sites[0].1, "users.find_one");
    }

    #[test]
    fn collect_picks_up_db_collection_aggregate() {
        let sites = collect_sites("db.users.aggregate([])\n");
        assert_eq!(sites.len(), 1);
        assert_eq!(sites[0].1, "db.users.aggregate");
    }

    #[test]
    fn collect_skips_bare_find_call() {
        let sites = collect_sites("find()\n");
        assert!(sites.is_empty(), "bare find() must not be picked up");
    }

    #[test]
    fn collect_skips_items_find_array_method() {
        let sites = collect_sites("items.find(predicate)\n");
        assert!(sites.is_empty(), "items.find must be skipped (array FP)");
    }

    #[test]
    fn collect_skips_list_find_array_method() {
        let sites = collect_sites("results.find(predicate)\n");
        assert!(sites.is_empty(), "results.find must be skipped");
    }

    // --- Query-shape classification --------------------------------------------------------

    #[test]
    fn classify_dict_expansion_request_json() {
        let sites = collect_sites("users.find_one({**request.json})\n");
        assert_eq!(sites[0].0, NosqlApi::DictExpansion);
    }

    #[test]
    fn classify_dict_expansion_request_body() {
        let sites = collect_sites("users.find({**request.body})\n");
        assert_eq!(sites[0].0, NosqlApi::DictExpansion);
    }

    #[test]
    fn classify_dict_expansion_get_json() {
        let sites = collect_sites("users.find_one({**request.get_json()})\n");
        assert_eq!(sites[0].0, NosqlApi::DictExpansion);
    }

    #[test]
    fn classify_dict_expansion_with_other_keys() {
        let sites = collect_sites("users.find_one({**request.json, \"active\": True})\n");
        assert_eq!(sites[0].0, NosqlApi::DictExpansion);
    }

    #[test]
    fn classify_where_with_user_input() {
        let src = "users.find_one({\"$where\": f\"this.x == '{request.form['x']}'\"})\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::OperatorInjection);
    }

    #[test]
    fn classify_function_operator_with_user_input() {
        let src = "users.aggregate([{\"$function\": {\"body\": request.json['code']}}])\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::OperatorInjection);
    }

    #[test]
    fn classify_where_without_user_input_is_ambiguous() {
        let src = "users.find({\"$where\": \"this.x == 1\"})\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::Ambiguous);
    }

    #[test]
    fn classify_typed_value_query_with_str_cast() {
        let src = "users.find_one({\"username\": str(request.form['user'])})\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::TypedValueQuery);
    }

    #[test]
    fn classify_typed_value_query_with_objectid_cast() {
        let src = "users.find({\"_id\": ObjectId(request.form['id'])})\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::TypedValueQuery);
    }

    #[test]
    fn classify_typed_value_query_with_no_user_input() {
        let src = "users.find({\"active\": True, \"role\": \"admin\"})\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::TypedValueQuery);
    }

    #[test]
    fn classify_typed_value_query_with_developer_operator() {
        let src = "users.find({\"role\": {\"$ne\": \"admin\"}})\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::TypedValueQuery);
    }

    #[test]
    fn classify_naked_request_json_value_is_ambiguous() {
        let src = "users.find_one({\"username\": request.json['user']})\n";
        let sites = collect_sites(src);
        assert_eq!(sites[0].0, NosqlApi::Ambiguous);
    }

    #[test]
    fn classify_identifier_arg_is_ambiguous() {
        let src = "q = build_query()\nusers.find(q)\n";
        let sites = collect_sites(src);
        // Only the `users.find` site matters here; other calls in the snippet are ignored.
        let pymongo: Vec<_> = sites.iter().filter(|(_, l)| l == "users.find").collect();
        assert_eq!(pymongo.len(), 1);
        assert_eq!(pymongo[0].0, NosqlApi::Ambiguous);
    }

    // --- Evidence extraction ---------------------------------------------------------------
    // Multi-line snippets use `\` line continuations; `\x20` keeps the leading space of the
    // Python body line, which the continuation would otherwise strip.

    #[test]
    fn detects_enclosing_function_name() {
        let src = "\
def get_user(user_id):\n\
\x20 return users.find_one({\"_id\": ObjectId(user_id)})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.enclosing_function.as_deref(), Some("get_user"));
    }

    #[test]
    fn detects_enclosing_class_name() {
        let src = "\
class UserRepo:\n\
\x20 def get(self, uid):\n\
\x20 return users.find_one({\"_id\": ObjectId(uid)})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.enclosing_class.as_deref(), Some("UserRepo"));
    }

    #[test]
    fn detects_flask_route_decorator() {
        let src = "\
from flask import request\n\
@app.route('/login', methods=['POST'])\n\
def login():\n\
\x20 return users.find_one({\"name\": request.form['n']})\n";
        let ev = extract(src, "find_one");
        assert!(ev.enclosing_route_handler);
    }

    #[test]
    fn detects_handler_name() {
        let src = "\
def login_handler():\n\
\x20 return users.find_one({})\n";
        let ev = extract(src, "find_one");
        assert!(ev.enclosing_route_handler);
    }

    #[test]
    fn detects_unstructured_json_source() {
        let src = "\
def f():\n\
\x20 body = request.json\n\
\x20 return users.find_one({\"x\": body['x']})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.user_input_source, UserInputSource::UnstructuredJson);
    }

    #[test]
    fn detects_typed_string_source() {
        let src = "\
def f():\n\
\x20 n = request.form['name']\n\
\x20 return users.find_one({\"name\": str(n)})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.user_input_source, UserInputSource::TypedString);
    }

    #[test]
    fn detects_no_user_input_source() {
        let src = "\
def f():\n\
\x20 return users.find_one({\"active\": True})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.user_input_source, UserInputSource::None);
    }

    #[test]
    fn detects_objectid_cast_nearby() {
        let src = "\
def f(uid):\n\
\x20 oid = ObjectId(uid)\n\
\x20 return users.find_one({\"_id\": oid})\n";
        let ev = extract(src, "find_one");
        assert!(ev.type_cast_nearby);
    }

    #[test]
    fn detects_pydantic_validation_nearby() {
        let src = "\
def f(payload):\n\
\x20 q = QuerySchema.model_validate(payload)\n\
\x20 return users.find_one({\"name\": q.name})\n";
        let ev = extract(src, "find_one");
        assert!(ev.type_cast_nearby);
    }

    #[test]
    fn detects_developer_written_ne() {
        let src = "users.find({\"role\": {\"$ne\": \"admin\"}})\n";
        let ev = extract(src, "find");
        assert!(ev.has_developer_written_operator);
    }

    #[test]
    fn does_not_detect_developer_written_when_user_input_present() {
        let src = "users.find({\"role\": {\"$ne\": request.form['r']}})\n";
        let ev = extract(src, "find");
        assert!(!ev.has_developer_written_operator);
    }

    #[test]
    fn does_not_detect_developer_written_for_plain_query() {
        let src = "users.find({\"role\": \"admin\"})\n";
        let ev = extract(src, "find");
        assert!(!ev.has_developer_written_operator);
    }

    // --- Inline `repotoire:` annotations ---------------------------------------------------

    #[test]
    fn extracts_nosql_safe_annotation() {
        let src = "\
def f():\n\
\x20 return users.find_one(q) # repotoire: nosql-safe[pydantic-validated]\n";
        let ev = extract(src, "find_one");
        assert_eq!(
            ev.nosql_safe_annotation.as_deref(),
            Some("pydantic-validated")
        );
    }

    #[test]
    fn extracts_nosql_vulnerable_annotation() {
        let src = "\
def f():\n\
\x20 return users.find_one(q) # repotoire: nosql-vulnerable[helper-built]\n";
        let ev = extract(src, "find_one");
        assert_eq!(
            ev.nosql_vulnerable_annotation.as_deref(),
            Some("helper-built")
        );
    }

    #[test]
    fn does_not_extract_unrelated_annotation() {
        let src = "\
def f():\n\
\x20 return users.find_one({}) # repotoire: jwt-safe[verified]\n";
        let ev = extract(src, "find_one");
        assert!(ev.nosql_safe_annotation.is_none());
        assert!(ev.nosql_vulnerable_annotation.is_none());
    }

    // --- End-to-end evidence for representative cases A-F ---------------------------------

    #[test]
    fn case_a_full_evidence_naked_request_json() {
        let src = "\
from flask import request\n\
@app.route('/login', methods=['POST'])\n\
def login():\n\
\x20 return users.find_one({\"username\": request.json['user']})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.api, Some(NosqlApi::Ambiguous));
        assert_eq!(ev.user_input_source, UserInputSource::UnstructuredJson);
        assert!(ev.enclosing_route_handler);
    }

    #[test]
    fn case_b_full_evidence_str_cast() {
        let src = "\
from flask import request\n\
def login():\n\
\x20 return users.find_one({\"username\": str(request.form['user'])})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.api, Some(NosqlApi::TypedValueQuery));
        assert_eq!(ev.user_input_source, UserInputSource::TypedString);
    }

    #[test]
    fn case_c_full_evidence_where_with_user_input() {
        let src = "\
from flask import request\n\
def f():\n\
\x20 return users.find_one({\"$where\": f\"this.x == '{request.form['x']}'\"})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.api, Some(NosqlApi::OperatorInjection));
    }

    #[test]
    fn case_d_full_evidence_developer_written_ne() {
        let src = "\
def list_users():\n\
\x20 return users.find({\"role\": {\"$ne\": \"admin\"}})\n";
        let ev = extract(src, "find");
        assert_eq!(ev.api, Some(NosqlApi::TypedValueQuery));
        assert!(ev.has_developer_written_operator);
    }

    #[test]
    fn case_e_full_evidence_dict_expansion() {
        let src = "\
from flask import request\n\
@app.route('/q', methods=['POST'])\n\
def f():\n\
\x20 return users.find_one({**request.get_json()})\n";
        let ev = extract(src, "find_one");
        assert_eq!(ev.api, Some(NosqlApi::DictExpansion));
        assert!(ev.enclosing_route_handler);
    }

    #[test]
    fn case_f_full_evidence_objectid_cast() {
        let src = "\
from flask import request\n\
def f():\n\
\x20 return users.find({\"_id\": ObjectId(request.form['id'])})\n";
        let ev = extract(src, "find");
        assert_eq!(ev.api, Some(NosqlApi::TypedValueQuery));
        assert!(ev.type_cast_nearby);
    }
}