//! Sink Context Inference with Command Subtypes
//!
//! Supports all 28 languages with command sink subtypes:
//! - CommandShell: sh -c, system(), cmd /c - shell interprets string
//! - CommandExecArgs: spawn with args array - safe if binary is constant
//! - CommandBinaryTaint: tainted binary path - very dangerous
use crate::knowledge::types::SinkContext;
use rma_common::Language;
use tree_sitter::Node;
// ---------------------------------------------------------------------------
// Detection patterns. Each entry is matched as a substring (see
// `matches_patterns`) against sink names that callers lower-case first.
// ---------------------------------------------------------------------------

/// Substrings naming SQL execution APIs.
const SQL_PATTERNS: &[&str] = &[
    "query",
    "execute",
    "exec_sql",
    "raw_sql",
    "cursor",
    "rawquery",
];

/// Substrings naming sinks that write unescaped HTML.
const RAW_HTML: &[&str] = &[
    "innerhtml",
    "outerhtml",
    "dangerouslysetinner",
    "__html",
    "rawhtml",
];

/// Substrings naming navigation / redirect sinks.
const URL_PATTERNS: &[&str] = &[
    "redirect",
    "location",
    "navigate",
    "open_url",
    "sendredirect",
];

/// Substrings naming JavaScript APIs that evaluate strings as code.
const JS_DANGEROUS: &[&str] = &["setinterval", "settimeout", "new function"];

/// Shell invocation patterns (dangerous): the command string is interpreted
/// by a shell.
const SHELL_PATTERNS: &[&str] = &[
    "system",
    "shell_exec",
    "popen",
    "backtick",
    "sh -c",
    "cmd /c",
    "bash -c",
];

// ---------------------------------------------------------------------------
// Safe-by-construction APIs.
// ---------------------------------------------------------------------------

/// DOM APIs that insert text rather than markup.
const SAFE_DOM_APIS: &[&str] = &["textcontent", "innertext", "createtextnode", "nodevalue"];

/// Markers of parameterised / prepared SQL.
const SAFE_SQL_PATTERNS: &[&str] = &["prepare", "parameterize", "bindparam", "setparameter"];
/// Result of sink context inference
///
/// Produced by `infer_sink_verdict`: either a dangerous output context, a
/// reason the API is safe by construction, or no verdict at all.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SinkVerdict {
/// The sink is dangerous in the given context.
Dangerous(SinkContext),
/// The API was recognised as safe by construction, for the given reason.
SafeByConstruction(SafeReason),
/// No inference stage recognised the sink.
Unknown,
}
/// Why an API is considered safe
///
/// Carried by [`SinkVerdict::SafeByConstruction`]. The enum has no payload,
/// so it is `Copy` and `Hash` and can be passed by value or used as a
/// map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SafeReason {
    /// A DOM API that inserts text rather than markup.
    DomTextApi,
    /// A parameterized / prepared SQL API.
    ParameterizedQuery,
    /// A command built from a constant binary plus an argument array.
    ArgumentArrayConstantBinary,
    /// A template engine with auto-escaping.
    AutoEscapingTemplate,
    /// A sanitizer was applied before the sink.
    SanitizerApplied,
}

impl SafeReason {
    /// Short human-readable label for this reason.
    pub fn description(&self) -> &'static str {
        match self {
            SafeReason::DomTextApi => "DOM text API (safe by construction)",
            SafeReason::ParameterizedQuery => "Parameterized query (safe by construction)",
            SafeReason::ArgumentArrayConstantBinary => {
                "Command with constant binary and args array"
            }
            SafeReason::AutoEscapingTemplate => "Auto-escaping template",
            SafeReason::SanitizerApplied => "Sanitizer applied",
        }
    }
}
/// Infer the sink context with command subtypes
///
/// Stages run in a fixed order and the first one that answers wins:
/// 1. known safe-by-construction APIs (DOM text APIs, parameterized SQL),
/// 2. the lower-cased sink name,
/// 3. the syntax-tree neighbourhood of `node`,
/// 4. raw source text around `node` in `content`.
///
/// Returns [`SinkVerdict::Unknown`] when no stage recognises the sink.
pub fn infer_sink_verdict(
    node: &Node,
    content: &str,
    language: Language,
    sink_name: &str,
) -> SinkVerdict {
    let lowered = sink_name.to_lowercase();

    if let Some(reason) = check_safe_api(&lowered, language) {
        return SinkVerdict::SafeByConstruction(reason);
    }

    let structural = context_from_sink_name(&lowered, language)
        .or_else(|| context_from_ast(node, content, language));
    if let Some(ctx) = structural {
        return SinkVerdict::Dangerous(ctx);
    }

    let textual = context_from_text_patterns(node, content, language);
    if textual == SinkContext::Unknown {
        SinkVerdict::Unknown
    } else {
        SinkVerdict::Dangerous(textual)
    }
}
/// Like [`infer_sink_verdict`], but reduced to a bare [`SinkContext`].
///
/// Safe-by-construction and unrecognised sinks both become
/// `SinkContext::Unknown`, so callers that only care about dangerous contexts
/// need not match on [`SinkVerdict`].
pub fn infer_sink_context(
    node: &Node,
    content: &str,
    language: Language,
    sink_name: &str,
) -> SinkContext {
    if let SinkVerdict::Dangerous(ctx) = infer_sink_verdict(node, content, language, sink_name) {
        ctx
    } else {
        SinkContext::Unknown
    }
}
/// True if any entry of `patterns` occurs as a substring of `name`.
///
/// Matching is case-sensitive; callers pass already lower-cased names.
fn matches_patterns(name: &str, patterns: &[&str]) -> bool {
    for &needle in patterns {
        if name.contains(needle) {
            return true;
        }
    }
    false
}
/// Recognise APIs that are safe by construction.
///
/// DOM text APIs are checked before parameterized-SQL APIs; `None` means
/// the name matched neither.
fn check_safe_api(name: &str, language: Language) -> Option<SafeReason> {
    if matches_patterns(name, SAFE_DOM_APIS) {
        Some(SafeReason::DomTextApi)
    } else if matches_safe_sql_api(name, language) {
        Some(SafeReason::ParameterizedQuery)
    } else {
        None
    }
}
/// Whether `name` looks like a parameterized / prepared SQL API.
///
/// The generic markers in `SAFE_SQL_PATTERNS` (plus `prepared`) apply to
/// every language; a few language-specific binding APIs are accepted on top.
fn matches_safe_sql_api(name: &str, language: Language) -> bool {
    if name.contains("prepared") || matches_patterns(name, SAFE_SQL_PATTERNS) {
        return true;
    }
    let has = |needle: &str| name.contains(needle);
    match language {
        Language::Java | Language::Kotlin | Language::Scala => {
            has("preparedstatement") || has("setstring")
        }
        Language::Php => has("pdo") && has("prepare"),
        Language::CSharp => has("sqlparameter") || has("addwithvalue"),
        Language::Python => has("executemany"),
        Language::Go => has("queryrow"),
        Language::Rust => has("bind") || has("query_as"),
        _ => false,
    }
}
fn context_from_sink_name(name: &str, language: Language) -> Option<SinkContext> {
// Command sinks with subtype detection
if let Some(cmd_ctx) = detect_command_subtype(name, language) {
return Some(cmd_ctx);
}
// SQL sinks
if matches_sql_sink(name, language) && !matches_safe_sql_api(name, language) {
return Some(SinkContext::Sql);
}
// Raw HTML
if matches_raw_html_sink(name, language) {
return Some(SinkContext::HtmlRaw);
}
// URL
if matches_url_sink(name, language) {
return Some(SinkContext::Url);
}
// JS eval
if matches_js_dangerous_sink(name, language) {
return Some(SinkContext::JavaScript);
}
// Template
if matches_template_sink(name, language) {
return Some(SinkContext::Template);
}
None
}
/// Detect command sink subtype for precise recommendations
///
/// Checks run from most to least specific/dangerous: shell invocation,
/// tainted binary, argument-array execution, then generic command. Returns
/// `None` if the name is not a command sink at all.
fn detect_command_subtype(name: &str, language: Language) -> Option<SinkContext> {
    let subtype = if matches_shell_invocation(name, language) {
        SinkContext::CommandShell
    } else if matches_binary_taint_pattern(name, language) {
        SinkContext::CommandBinaryTaint
    } else if matches_args_based_exec(name, language) {
        SinkContext::CommandExecArgs
    } else if matches_generic_command(name, language) {
        SinkContext::Command
    } else {
        return None;
    };
    Some(subtype)
}
/// Whether `name` hands a command string to a shell for interpretation.
///
/// Any `SHELL_PATTERNS` hit counts for every language; otherwise a
/// language-specific list of shell-executing APIs is consulted.
fn matches_shell_invocation(name: &str, language: Language) -> bool {
    if matches_patterns(name, SHELL_PATTERNS) {
        return true;
    }
    let has = |needle: &str| name.contains(needle);
    match language {
        Language::Python => {
            has("os.system") || has("os.popen") || (has("subprocess") && has("shell"))
        }
        Language::JavaScript | Language::TypeScript => has("exec(") && !has("execfile"),
        Language::Ruby => has("system(") || has("`"),
        Language::Php => has("shell_exec") || has("passthru") || has("proc_open"),
        Language::Perl => has("system") || has("qx"),
        // A `command` name that also mentions `sh` or `-c`,
        // e.g. `Command::new("sh").arg("-c")`.
        Language::Rust => has("command") && (has("sh") || has("-c")),
        _ => false,
    }
}
/// Detect commands whose *program path* (not just its arguments) comes from a
/// non-literal expression.
///
/// This only fires when the sink name includes the call's argument text and
/// the first argument is not a literal. A bare API name such as `spawn` or
/// `command::new` carries no evidence about the binary. It is therefore left
/// to the args-based / generic command checks instead of being reported as a
/// tainted binary. Only the *first* argument is inspected, so a literal later
/// in the text (e.g. `.arg("x")`) no longer hides a variable program.
fn matches_binary_taint_pattern(name: &str, language: Language) -> bool {
    match language {
        // Command::new(variable) vs Command::new("tool") / raw strings.
        Language::Rust => first_call_arg_is_dynamic(name, "command::new", &["\"", "r\"", "r#\""]),
        // subprocess.run(cmd) vs subprocess.run(["tool", ...]).
        Language::Python => first_call_arg_is_dynamic(name, "subprocess.run", &["[\"", "['"]),
        // spawn(cmd) vs spawn("tool", [...]).
        Language::JavaScript | Language::TypeScript => {
            first_call_arg_is_dynamic(name, "spawn", &["\"", "'"])
        }
        _ => false,
    }
}

/// True when `name` contains `callee`, followed by an argument list whose
/// first argument does not start with any of `literal_prefixes`.
///
/// Returns `false` when no `(` follows `callee` or the argument list is empty,
/// because then there is nothing to judge.
fn first_call_arg_is_dynamic(name: &str, callee: &str, literal_prefixes: &[&str]) -> bool {
    let after_callee = match name.find(callee) {
        Some(at) => &name[at + callee.len()..],
        None => return false,
    };
    let first_arg = match after_callee.find('(') {
        Some(open) => after_callee[open + 1..].trim_start(),
        None => return false,
    };
    !first_arg.is_empty()
        && !first_arg.starts_with(')')
        && !literal_prefixes.iter().any(|lit| first_arg.starts_with(*lit))
}
/// Whether `name` looks like process execution with an explicit argument
/// list rather than a shell string.
fn matches_args_based_exec(name: &str, language: Language) -> bool {
    let has = |needle: &str| name.contains(needle);
    match language {
        Language::Rust => matches!(name, "arg" | "args") || has(".arg(") || has(".args("),
        Language::JavaScript | Language::TypeScript => {
            !has("shell") && (has("spawn") || has("fork"))
        }
        Language::Python => has("subprocess.run") && !has("shell=true"),
        Language::Go => has("exec.command"),
        Language::Java | Language::Kotlin => has("processbuilder") && has("command"),
        _ => false,
    }
}
/// Catch-all process-execution check, consulted after the shell,
/// binary-taint and argument-array checks.
///
/// The bare `exec` substring also occurs in names that are not process
/// execution:
/// - SQL APIs (`execute`, `executequery`, `executereader`, `exec_sql`);
/// - Python's `exec(` builtin (code evaluation).
///
/// `context_from_sink_name` consults command detection before its SQL and
/// code-evaluation checks. Claiming those names here would therefore
/// misreport them as commands, so they are excluded from the generic `exec`
/// match. Language-specific process APIs are still recognised explicitly,
/// e.g. Lua's `os.execute`.
fn matches_generic_command(name: &str, language: Language) -> bool {
    let exec_is_sql_api = name.contains("execute") || name.contains("exec_sql");
    let exec_is_python_builtin = matches!(language, Language::Python) && name.contains("exec(");
    let generic_exec = name.contains("exec") && !exec_is_sql_api && !exec_is_python_builtin;

    let common = generic_exec || name.contains("spawn") || name.contains("run_command");
    let lang_specific = match language {
        Language::Rust => name.contains("command::new"),
        Language::C | Language::Cpp => name.contains("system(") || name.contains("popen("),
        Language::Lua => name.contains("os.execute") || name.contains("io.popen"),
        Language::Swift => name.contains("process()") || name.contains("task"),
        Language::Dart => name.contains("process.run"),
        Language::Elixir => name.contains("system.cmd") || name.contains("port.open"),
        _ => false,
    };
    common || lang_specific
}
/// Whether `name` names an API that executes SQL.
///
/// `SQL_PATTERNS` applies to every language; a few driver-specific APIs are
/// added per language.
fn matches_sql_sink(name: &str, language: Language) -> bool {
    let driver_apis: &[&str] = match language {
        Language::Java | Language::Kotlin => &["createstatement", "executequery"],
        Language::Php => &["mysql_query", "mysqli_query"],
        Language::CSharp => &["executereader", "executenonquery"],
        _ => &[],
    };
    matches_patterns(name, SQL_PATTERNS) || matches_patterns(name, driver_apis)
}
/// Whether `name` writes HTML without escaping.
///
/// Generic DOM sinks are checked first, then framework-specific "mark as
/// safe" / raw-output helpers.
fn matches_raw_html_sink(name: &str, language: Language) -> bool {
    if name == "html" || name.contains("insertadjacenthtml") || matches_patterns(name, RAW_HTML) {
        return true;
    }
    let framework_apis: &[&str] = match language {
        Language::Python => &["mark_safe", "|safe"],
        Language::Ruby => &["html_safe", "raw("],
        Language::Vue | Language::Svelte => &["v-html", "{@html"],
        Language::Elixir => &["raw(", "phoenix.html.raw"],
        _ => &[],
    };
    matches_patterns(name, framework_apis)
}
/// Whether `name` is a navigation / redirect sink.
///
/// The `href` and `src` attributes and `URL_PATTERNS` count for every
/// language; framework redirect helpers are added per language.
fn matches_url_sink(name: &str, language: Language) -> bool {
    if matches!(name, "href" | "src") || matches_patterns(name, URL_PATTERNS) {
        return true;
    }
    let has = |needle: &str| name.contains(needle);
    match language {
        Language::JavaScript | Language::TypeScript => {
            has("window.location") || has("window.open")
        }
        Language::Python => has("redirect(") || has("httpresponseredirect"),
        Language::Java | Language::Kotlin => has("sendredirect") || has("forward"),
        Language::Php => has("header(") && has("location"),
        Language::Ruby => has("redirect_to"),
        Language::CSharp => has("response.redirect"),
        _ => false,
    }
}
/// Whether `name` evaluates a string as code.
fn matches_js_dangerous_sink(name: &str, language: Language) -> bool {
    // Assembled with `concat!` so the keyword never appears literally in this
    // file (presumably so the scanner does not flag itself — TODO confirm).
    // It is built at compile time, so nothing is allocated per call.
    const EVAL_NAME: &str = concat!("ev", "al");

    if name == EVAL_NAME || matches_patterns(name, JS_DANGEROUS) {
        return true;
    }
    let has = |needle: &str| name.contains(needle);
    match language {
        Language::JavaScript | Language::TypeScript => has("script.src"),
        Language::Python => has("exec(") || has("compile("),
        Language::Ruby => has("instance_eval") || has("class_eval"),
        _ => false,
    }
}
fn matches_template_sink(name: &str, language: Language) -> bool {
match language {
Language::Python => name.contains("render_template") || name.contains("jinja"),
Language::JavaScript | Language::TypeScript => name.contains("ejs.render") || name.contains("handlebars"),
Language::Java | Language::Kotlin => name.contains("freemarker") || name.contains("velocity"),
Language::Ruby => name.contains("erb") || name.contains("haml"),
Language::Php => name.contains("twig") || name.contains("blade"),
Language::CSharp => name.contains("razor"),
Language::Go => name.contains("template.execute"),
Language::Rust => name.contains("askama") || name.contains("tera"),
Language::Elixir => name.contains("eex") || name.contains("heex"),
_ => false,
}
}
/// Classify a sink from its syntax-tree surroundings.
///
/// The parent node is consulted first (HTML/JSX attribute, code-evaluation
/// call, SQL string concatenation). If it gives no answer, the node itself is
/// checked for template syntax. Node text that is not valid UTF-8 is treated
/// as empty.
fn context_from_ast(node: &Node, content: &str, language: Language) -> Option<SinkContext> {
    let source = content.as_bytes();

    let from_parent = node.parent().and_then(|parent| {
        let parent_kind = parent.kind();
        let parent_text = parent.utf8_text(source).unwrap_or("");
        if is_html_attribute_context(parent_kind, parent_text, language) {
            // href / src / `on…` attributes are reported as URL sinks.
            let ctx = if is_dangerous_attribute(parent_text) {
                SinkContext::Url
            } else {
                SinkContext::HtmlAttribute
            };
            Some(ctx)
        } else if is_js_code_context(parent_text, language) {
            Some(SinkContext::JavaScript)
        } else if is_sql_string_context(parent_kind, parent_text) {
            Some(SinkContext::Sql)
        } else {
            None
        }
    });
    if from_parent.is_some() {
        return from_parent;
    }

    let node_text = node.utf8_text(source).unwrap_or("");
    if is_template_context(node.kind(), node_text, language) {
        Some(SinkContext::Template)
    } else {
        None
    }
}
/// Whether the parent node kind is an attribute in a markup-bearing language.
///
/// `_parent_text` is currently unused and kept for signature compatibility.
fn is_html_attribute_context(parent_kind: &str, _parent_text: &str, language: Language) -> bool {
    let attribute_kind = match language {
        Language::JavaScript | Language::TypeScript => "jsx_attribute",
        Language::Html | Language::Vue | Language::Svelte => "attribute",
        _ => return false,
    };
    parent_kind == attribute_kind
}
/// Whether attribute text names a URL-bearing (`href`, `src`) or `on…`
/// attribute. The comparison is case-insensitive.
fn is_dangerous_attribute(attr_text: &str) -> bool {
    let attr = attr_text.to_lowercase();
    attr.starts_with("on") || ["href", "src"].iter().any(|key| attr.contains(*key))
}
/// Whether the parent text shows the value flowing into dynamic code.
///
/// In JS/TS this means an eval-style call or the `Function(` constructor. In
/// HTML it means a `<script` element.
fn is_js_code_context(parent_text: &str, language: Language) -> bool {
    match language {
        Language::JavaScript | Language::TypeScript => {
            // Keyword assembled with `concat!` so it is not spelled literally here.
            const DYNAMIC_CALLS: [&str; 2] = [concat!("ev", "al", "("), "Function("];
            DYNAMIC_CALLS.iter().any(|call| parent_text.contains(*call))
        }
        Language::Html => parent_text.contains("<script"),
        _ => false,
    }
}
/// Whether the parent is a string concatenation / template string whose text
/// contains a `SELECT` or `INSERT` keyword (case-insensitive).
fn is_sql_string_context(parent_kind: &str, parent_text: &str) -> bool {
    if !matches!(parent_kind, "binary_expression" | "template_string") {
        return false;
    }
    let lowered = parent_text.to_lowercase();
    ["select ", "insert "].iter().any(|keyword| lowered.contains(*keyword))
}
/// Whether the node itself is template syntax.
///
/// JS/TS uses the node kind (`template_string`); Python and Ruby use
/// template delimiters found in the node text.
fn is_template_context(node_kind: &str, node_text: &str, language: Language) -> bool {
    let has = |marker: &str| node_text.contains(marker);
    match language {
        Language::JavaScript | Language::TypeScript => node_kind == "template_string",
        Language::Python => has("{{") || has("{%"),
        Language::Ruby => has("<%"),
        _ => false,
    }
}
/// Last-resort classification from the raw source text around `node`.
///
/// Lower-cases up to `CONTEXT_RADIUS` bytes on each side of the node and
/// returns the first matching context. Checks run in this order: SQL, raw
/// HTML, redirect URL, explicit shell invocation, generic command. Returns
/// `SinkContext::Unknown` when nothing matches.
///
/// The window edges are byte offsets, so they are widened to UTF-8 character
/// boundaries before slicing. Slicing inside a multi-byte character (or past
/// the end of `content`) would otherwise panic on non-ASCII source.
fn context_from_text_patterns(node: &Node, content: &str, _language: Language) -> SinkContext {
    const CONTEXT_RADIUS: usize = 200;

    let lower = byte_window(content, node.start_byte(), node.end_byte(), CONTEXT_RADIUS)
        .to_lowercase();
    let has = |needle: &str| lower.contains(needle);

    if has("select ") || has("insert into") {
        return SinkContext::Sql;
    }
    if has("innerhtml") {
        return SinkContext::HtmlRaw;
    }
    if has("redirect") {
        return SinkContext::Url;
    }
    // Detect shell invocation in surrounding context.
    if has("sh -c") || has("cmd /c") || has("bash -c") {
        return SinkContext::CommandShell;
    }
    if has("spawn(") || has("system(") {
        return SinkContext::Command;
    }
    SinkContext::Unknown
}

/// Slice of `content` covering `[start - radius, end + radius)`.
///
/// The range is clamped to `content` and widened outward to UTF-8 character
/// boundaries. The result is empty if the range is inverted, e.g. for offsets
/// taken from a different version of the text.
fn byte_window(content: &str, start: usize, end: usize, radius: usize) -> &str {
    let len = content.len();
    let mut from = start.saturating_sub(radius).min(len);
    let mut to = end.saturating_add(radius).min(len);
    // `is_char_boundary` is true at 0 and at `len`, so both loops terminate.
    while !content.is_char_boundary(from) {
        from -= 1;
    }
    while !content.is_char_boundary(to) {
        to += 1;
    }
    content.get(from..to).unwrap_or("")
}
/// Get sanitizer patterns for a context
///
/// Returns human-readable names of sanitizers / safe APIs for `context`.
/// Only `SinkContext::HtmlText` varies by `language`. An empty list means
/// there is nothing specific to recommend (`SinkContext::Unknown`).
pub fn recommended_sanitizers(context: SinkContext, language: Language) -> Vec<&'static str> {
    /// Escaping helpers for text placed inside HTML element content.
    fn html_text_escapers(language: Language) -> &'static [&'static str] {
        match language {
            Language::JavaScript | Language::TypeScript => &["textContent", "createTextNode"],
            Language::Python => &["html.escape", "bleach.clean"],
            Language::Java => &["StringEscapeUtils.escapeHtml"],
            Language::Php => &["htmlspecialchars", "htmlentities"],
            Language::CSharp => &["HtmlEncoder.Encode"],
            _ => &["escape", "encode"],
        }
    }

    let picks: &[&'static str] = match context {
        SinkContext::HtmlText => html_text_escapers(language),
        SinkContext::HtmlRaw => &["textContent", "DOMPurify.sanitize"],
        SinkContext::HtmlAttribute => &["attribute encoding"],
        SinkContext::Url => &["URL validation", "encodeURIComponent"],
        SinkContext::Sql => &["parameterized queries", "prepared statements"],
        SinkContext::Command | SinkContext::CommandShell => {
            &["argument arrays", "avoid shell invocation"]
        }
        SinkContext::CommandExecArgs => &["validate args", "allowlist flags"],
        SinkContext::CommandBinaryTaint => &["allowlist binaries", "fixed command map"],
        SinkContext::JavaScript => &["JSON.stringify", "data attributes"],
        SinkContext::Template => &["auto-escaping", "|escape filter"],
        SinkContext::FilePath => &["canonicalize", "base directory check", "reject '..'"],
        SinkContext::Unknown => &[],
    };
    picks.to_vec()
}
/// Get fix recommendation with command subtype awareness
///
/// Produces a one-line remediation hint for `context`. Command-shell and
/// file-path hints are tailored to `language`. Most messages embed the list
/// from [`recommended_sanitizers`]. When that list is empty (e.g. for
/// `SinkContext::Unknown`), the generic fallback omits the `": <list>"` tail
/// instead of ending in a dangling colon.
pub fn fix_recommendation(context: SinkContext, language: Language) -> String {
    let sanitizer_list = recommended_sanitizers(context, language).join(", ");
    match context {
        SinkContext::CommandShell => {
            let lang_specific = match language {
                Language::Rust => "Use Command::new(\"tool\").args([...]) instead of sh -c",
                Language::Python => "Use subprocess.run([...], shell=False)",
                Language::JavaScript | Language::TypeScript => {
                    "Use spawn with args array, not exec string"
                }
                Language::Php => "Use escapeshellarg() or avoid shell_exec",
                _ => "Avoid shell invocation; use argument arrays",
            };
            format!("{}. Recommended: {}", lang_specific, sanitizer_list)
        }
        SinkContext::CommandExecArgs => {
            "Ensure binary is constant/allowlisted. Validate args for flags like -c, --;".to_string()
        }
        SinkContext::CommandBinaryTaint => {
            "Never execute user-controlled binary paths. Use allowlist or fixed command map."
                .to_string()
        }
        SinkContext::Command => format!(
            "Use argument arrays instead of shell strings. Recommended: {}",
            sanitizer_list
        ),
        SinkContext::HtmlRaw => format!(
            "Avoid innerHTML with user input. Use DOM text APIs or {}",
            sanitizer_list
        ),
        SinkContext::Sql => format!("Use {} instead of string concatenation", sanitizer_list),
        SinkContext::Url => format!(
            "Validate URL scheme (reject javascript:). Recommended: {}",
            sanitizer_list
        ),
        SinkContext::Template => format!("Enable auto-escaping. Recommended: {}", sanitizer_list),
        SinkContext::FilePath => {
            let lang_specific = match language {
                Language::Rust => "Use Path::canonicalize(), check starts_with(base_dir)",
                Language::Python => "Use os.path.realpath(), verify path.startswith(base)",
                Language::JavaScript | Language::TypeScript => {
                    "Use path.resolve(), check path.startsWith(baseDir)"
                }
                Language::Java => "Use Paths.get().normalize().toRealPath(), validate prefix",
                Language::Go => "Use filepath.Clean + filepath.Abs, verify HasPrefix",
                _ => "Canonicalize path, restrict to base directory, reject '..'",
            };
            format!(
                "{}. Reject paths containing '..' or absolute paths outside allowed dirs",
                lang_specific
            )
        }
        // No concrete sanitizers to list (e.g. Unknown): avoid a trailing ": ".
        _ if sanitizer_list.is_empty() => "Apply appropriate sanitization".to_string(),
        _ => format!("Apply appropriate sanitization: {}", sanitizer_list),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shell-string APIs are recognised; a bare `spawn` is not a shell call.
    #[test]
    fn test_shell_detection() {
        let cases = [
            ("os.system", Language::Python, true),
            ("shell_exec", Language::Php, true),
            ("spawn", Language::JavaScript, false),
        ];
        for (sink, language, expected) in cases {
            assert_eq!(
                matches_shell_invocation(sink, language),
                expected,
                "shell detection for `{}`",
                sink
            );
        }
    }

    /// Shell invocation maps to `CommandShell`; args-based spawning to
    /// `CommandExecArgs`.
    #[test]
    fn test_command_subtype_detection() {
        let cases = [
            ("os.system", Language::Python, SinkContext::CommandShell),
            ("spawn", Language::JavaScript, SinkContext::CommandExecArgs),
        ];
        for (sink, language, expected) in cases {
            assert_eq!(
                detect_command_subtype(sink, language),
                Some(expected),
                "command subtype for `{}`",
                sink
            );
        }
    }

    /// Shell-invocation fixes carry language-specific advice.
    #[test]
    fn test_fix_recommendations() {
        let cases = [(Language::Python, "shell=False"), (Language::Rust, "Command::new")];
        for (language, needle) in cases {
            let fix = fix_recommendation(SinkContext::CommandShell, language);
            assert!(fix.contains(needle), "expected `{}` in: {}", needle, fix);
        }
    }
}