#![allow(clippy::if_same_then_else)]
use std::collections::{BTreeMap, HashMap};
use super::const_prop::ConstLattice;
use super::ir::*;
use crate::cfg::{BinOp, Cfg};
use crate::symbol::Lang;
use serde::{Deserialize, Serialize};
/// Coarse type classification attached to SSA values by [`analyze_types`].
///
/// Besides primitive kinds, several "resource" kinds model handles such as HTTP
/// clients, database connections and files; [`TypeKind::label_prefix`] maps those
/// to a label prefix string. `Unknown` is the fallback whenever no fact could be
/// inferred, and `Dto` carries a class name plus per-field kinds.
// NOTE(review): `allow(dead_code)` presumably because not every variant is
// constructed (e.g. `Array` has no producer in this file) — confirm.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[allow(dead_code)]
pub enum TypeKind {
    String,
    Int,
    Bool,
    Object,
    Array,
    Null,
    /// No fact could be established; also the result of meeting two differing kinds.
    Unknown,
    HttpResponse,
    DatabaseConnection,
    FileHandle,
    Url,
    HttpClient,
    RequestBuilder,
    /// An in-memory collection produced by a recognised constructor call
    /// (e.g. JS `Map`/`Set`, Rust `HashMap::new`/`Vec::with_capacity`).
    LocalCollection,
    /// A data-transfer object with a known class name and typed fields.
    Dto(DtoFields),
}
/// Field layout of a DTO-like class: its name plus the kind of each named field.
///
/// Fields live in a `BTreeMap`, so iteration and serialization order is
/// deterministic (sorted by field name).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct DtoFields {
    /// Class name reported by [`TypeKind::container_name`] for `TypeKind::Dto`.
    pub class_name: String,
    /// Field name -> inferred kind of that field.
    pub fields: BTreeMap<String, TypeKind>,
}
impl DtoFields {
    /// Creates an empty field layout for the class named `class_name`.
    pub fn new(class_name: impl Into<String>) -> Self {
        let class_name = class_name.into();
        Self {
            fields: BTreeMap::new(),
            class_name,
        }
    }

    /// Records (or overwrites) the kind of `field`.
    pub fn insert(&mut self, field: impl Into<String>, kind: TypeKind) {
        let key: String = field.into();
        self.fields.insert(key, kind);
    }

    /// Looks up the recorded kind of `field`, if any.
    pub fn get(&self, field: &str) -> Option<&TypeKind> {
        let fields = &self.fields;
        fields.get(field)
    }
}
impl TypeKind {
    /// Label prefix for resource-like kinds (`HttpClient`, `HttpResponse`,
    /// `DatabaseConnection`, `FileHandle`, `URL`, `RequestBuilder`); `None` for
    /// every other kind.
    pub fn label_prefix(&self) -> Option<&'static str> {
        let prefix = match self {
            Self::HttpClient => "HttpClient",
            Self::HttpResponse => "HttpResponse",
            Self::DatabaseConnection => "DatabaseConnection",
            Self::FileHandle => "FileHandle",
            Self::Url => "URL",
            Self::RequestBuilder => "RequestBuilder",
            _ => return None,
        };
        Some(prefix)
    }

    /// Owned container name: the label prefix for resource kinds, the class
    /// name for DTOs, otherwise `None`.
    pub fn container_name(&self) -> Option<String> {
        self.label_prefix()
            .map(str::to_owned)
            .or_else(|| self.as_dto().map(|dto| dto.class_name.clone()))
    }

    /// Borrows the DTO field layout when this is `TypeKind::Dto`.
    pub fn as_dto(&self) -> Option<&DtoFields> {
        if let Self::Dto(layout) = self {
            Some(layout)
        } else {
            None
        }
    }
}
/// A type fact for one SSA value: its [`TypeKind`] plus whether it may be null.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct TypeFact {
    /// Inferred kind of the value.
    pub kind: TypeKind,
    /// `true` when the value may be null. Set for `TypeKind::Null` facts and
    /// OR-ed together when facts are met at phi nodes.
    pub nullable: bool,
}
impl TypeFact {
    /// The "nothing known" fact: `Unknown`, not nullable.
    fn unknown() -> Self {
        Self::from_kind(TypeKind::Unknown)
    }

    /// Wraps `kind` in a fact; only a `Null` kind is marked nullable.
    fn from_kind(kind: TypeKind) -> Self {
        Self {
            nullable: kind == TypeKind::Null,
            kind,
        }
    }

    /// Lattice meet used at phi nodes: kinds must agree exactly or the result
    /// degrades to `Unknown`; nullability is the OR of both sides.
    fn meet(&self, other: &Self) -> Self {
        let kind = if self.kind != other.kind {
            TypeKind::Unknown
        } else {
            self.kind.clone()
        };
        Self {
            kind,
            nullable: self.nullable || other.nullable,
        }
    }

    /// Fact for reading `field` off a DTO-typed `receiver`; `None` when the
    /// receiver is not a DTO or the field is not recorded.
    pub(crate) fn from_dto_field(receiver: &TypeKind, field: &str) -> Option<Self> {
        receiver
            .as_dto()
            .and_then(|dto| dto.get(field))
            .cloned()
            .map(Self::from_kind)
    }
}
/// Output of [`analyze_types`]: one [`TypeFact`] per SSA value the analysis visited.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TypeFactResult {
    /// SSA value -> inferred fact. Values absent from the map were never visited.
    pub facts: HashMap<SsaValue, TypeFact>,
}
impl TypeFactResult {
    /// `true` when `v` has a fact whose kind is exactly `Int`.
    pub fn is_int(&self, v: SsaValue) -> bool {
        matches!(self.get_type(v), Some(TypeKind::Int))
    }

    /// Kind recorded for `v`, or `None` when `v` has no fact at all.
    pub fn get_type(&self, v: SsaValue) -> Option<&TypeKind> {
        let fact = self.facts.get(&v)?;
        Some(&fact.kind)
    }

    /// `true` when `v` has a fact whose kind equals `kind`.
    pub fn is_type(&self, v: SsaValue, kind: &TypeKind) -> bool {
        self.get_type(v) == Some(kind)
    }
}
/// Decides whether a sink finding can be dismissed purely on type grounds.
///
/// Returns `true` only when all of the following hold:
/// * `sink_caps` shares at least one bit with the injection-style caps
///   (SQL, file I/O, shell, HTML, SSRF, data exfiltration);
/// * `values` is non-empty;
/// * every value has a recorded kind of `Int` or `Bool`.
///
/// A value with no fact (or any other kind) keeps the finding alive.
pub fn is_type_safe_for_sink(
    values: &[SsaValue],
    sink_caps: crate::labels::Cap,
    type_facts: &TypeFactResult,
) -> bool {
    use crate::labels::Cap;
    let suppressible = Cap::SQL_QUERY
        | Cap::FILE_IO
        | Cap::SHELL_ESCAPE
        | Cap::HTML_ESCAPE
        | Cap::SSRF
        | Cap::DATA_EXFIL;
    let cap_is_suppressible = sink_caps.intersects(suppressible);
    cap_is_suppressible
        && !values.is_empty()
        && values.iter().all(|&value| {
            matches!(
                type_facts.get_type(value),
                Some(TypeKind::Int | TypeKind::Bool)
            )
        })
}
/// Infers the [`TypeKind`] produced by calling `callee` in language `lang`, when
/// the callee is a recognised constructor/factory for a tracked resource type.
///
/// Matching is purely textual. `after_colons` is the text after the last `::`,
/// and `suffix` is the text after the last `.` of that, so both `java.net.URL`
/// and `URL` yield the suffix `URL`. Within each language, checks are tried in
/// order, so earlier, more specific checks win over later, broader ones.
///
/// Returns `None` when the callee is not recognised.
pub(crate) fn constructor_type(lang: Lang, callee: &str) -> Option<TypeKind> {
    // Last `::`-separated segment, then the last `.`-separated segment of that.
    let after_colons = callee.rsplit("::").next().unwrap_or(callee);
    let suffix = after_colons.rsplit('.').next().unwrap_or(after_colons);
    match lang {
        Lang::Java => match suffix {
            "URL" | "URI" => Some(TypeKind::Url),
            // Static factories only count when the qualifier mentions HttpClient.
            "newHttpClient" | "newBuilder" if callee.contains("HttpClient") => {
                Some(TypeKind::HttpClient)
            }
            "createDefault" | "custom" if callee.contains("HttpClient") => {
                Some(TypeKind::HttpClient)
            }
            "OkHttpClient" | "WebClient" | "RestTemplate" => Some(TypeKind::HttpClient),
            "getConnection" => Some(TypeKind::DatabaseConnection),
            "MongoClient" => Some(TypeKind::DatabaseConnection),
            // Statement factories are also classified as DatabaseConnection.
            "createStatement" | "prepareCall" => Some(TypeKind::DatabaseConnection),
            "FileInputStream" | "FileOutputStream" | "FileReader" | "FileWriter"
            | "BufferedReader" | "BufferedWriter" => Some(TypeKind::FileHandle),
            // Presumably servlet `response.getWriter()` / `getOutputStream()` — confirm.
            "getWriter" | "getOutputStream" => Some(TypeKind::HttpResponse),
            _ => None,
        },
        Lang::JavaScript | Lang::TypeScript => match suffix {
            "URL" => Some(TypeKind::Url),
            "Request" | "XMLHttpRequest" => Some(TypeKind::HttpClient),
            "Map" | "Set" | "WeakMap" | "WeakSet" | "Array" => Some(TypeKind::LocalCollection),
            _ => None,
        },
        Lang::Python => {
            if callee.starts_with("requests.")
                || callee == "urlopen"
                || callee == "aiohttp.ClientSession"
                || callee.starts_with("httpx.")
                || callee == "urllib3.PoolManager"
            {
                Some(TypeKind::HttpClient)
            } else if suffix == "connect"
                && (callee.contains("sqlite3")
                    || callee.contains("psycopg2")
                    || callee.contains("mysql"))
            {
                Some(TypeKind::DatabaseConnection)
            } else if suffix == "open" && !callee.contains('.') {
                // Only the unqualified builtin `open(...)`, not `x.open(...)`.
                Some(TypeKind::FileHandle)
            } else {
                None
            }
        }
        Lang::Go => {
            if callee.contains("http.") && matches!(suffix, "NewRequest" | "Get" | "Post") {
                Some(TypeKind::HttpClient)
            } else if callee.contains("sql.") && suffix == "Open" {
                Some(TypeKind::DatabaseConnection)
            } else if callee.contains("os.") && matches!(suffix, "Open" | "Create" | "OpenFile") {
                Some(TypeKind::FileHandle)
            } else if callee.contains("url.") && suffix == "Parse" {
                Some(TypeKind::Url)
            } else {
                None
            }
        }
        Lang::Php => match suffix {
            "PDO" | "mysqli" => Some(TypeKind::DatabaseConnection),
            "curl_init" => Some(TypeKind::HttpClient),
            "fopen" => Some(TypeKind::FileHandle),
            "SplFileObject" => Some(TypeKind::FileHandle),
            _ => None,
        },
        Lang::C => match suffix {
            "fopen" => Some(TypeKind::FileHandle),
            "curl_easy_init" => Some(TypeKind::HttpClient),
            "mysql_real_connect" | "PQconnectdb" => Some(TypeKind::DatabaseConnection),
            _ => None,
        },
        Lang::Cpp => match suffix {
            "fopen" | "ifstream" | "ofstream" | "fstream" => Some(TypeKind::FileHandle),
            "curl_easy_init" => Some(TypeKind::HttpClient),
            "mysql_real_connect" | "PQconnectdb" => Some(TypeKind::DatabaseConnection),
            _ => None,
        },
        Lang::Rust => {
            // Strip trailing `.unwrap()` / `.await` / … so `File::open(p)?`-style
            // chains are matched on the constructor path itself.
            let base = peel_identity_suffix(callee);
            let base = base.as_str();
            if base.ends_with("reqwest::Client::new") || base.ends_with("reqwest::get") {
                Some(TypeKind::HttpClient)
            } else if base.contains("HttpResponse::") || base.ends_with("Response::builder") {
                // Checked before the request-builder test below.
                Some(TypeKind::HttpResponse)
            } else if base.ends_with("File::open") || base.ends_with("File::create") {
                Some(TypeKind::FileHandle)
            } else if base.ends_with("Url::parse") {
                Some(TypeKind::Url)
            } else if base.ends_with("rusqlite::Connection::open")
                || base.ends_with("Connection::open")
                || base.ends_with("postgres::Client::connect")
                || base.ends_with("sqlx::PgPool::connect")
                || base.ends_with("sqlx::SqlitePool::connect")
                || base.ends_with("sqlx::MySqlPool::connect")
            {
                Some(TypeKind::DatabaseConnection)
            } else if base.ends_with("diesel::PgConnection::establish")
                || base.ends_with("diesel::SqliteConnection::establish")
                || base.ends_with("PgConnection::establish")
                || base.ends_with("SqliteConnection::establish")
            {
                // Same result as the branch above; the file-level
                // `clippy::if_same_then_else` allow covers this split.
                Some(TypeKind::DatabaseConnection)
            } else if is_rust_local_collection_constructor(base) {
                Some(TypeKind::LocalCollection)
            } else if is_rust_request_builder_constructor(base) {
                Some(TypeKind::RequestBuilder)
            } else {
                None
            }
        }
        Lang::Ruby => {
            // For `Net::HTTP.new`, `after_colons` is `HTTP.new`, hence the
            // whole-callee check for `Net::HTTP`.
            if callee.contains("Net::HTTP") || after_colons.starts_with("HTTParty") {
                Some(TypeKind::HttpClient)
            } else if after_colons.starts_with("URI") && matches!(suffix, "parse" | "URI") {
                Some(TypeKind::Url)
            } else if after_colons == "PG.connect"
                || (after_colons.starts_with("Sequel") && suffix == "connect")
                || callee.contains("Mysql2")
            {
                Some(TypeKind::DatabaseConnection)
            } else if after_colons.starts_with("File.") && matches!(suffix, "open" | "new") {
                Some(TypeKind::FileHandle)
            } else {
                None
            }
        }
    }
}
/// Normalises `callee` and strips any trailing identity-method segments.
///
/// The callee is first passed through
/// `crate::labels::normalize_chained_call_for_classify`. Everything from the
/// first `(` onward is then dropped. Finally, trailing `.segment`s are removed
/// while [`is_identity_method`] accepts them, so `Foo::new.unwrap.clone`
/// becomes `Foo::new`.
pub fn peel_identity_suffix(callee: &str) -> String {
    let mut peeled = crate::labels::normalize_chained_call_for_classify(callee);
    if let Some(open_paren) = peeled.find('(') {
        peeled.truncate(open_paren);
    }
    loop {
        let Some((head, last_segment)) = peeled.rsplit_once('.') else {
            break;
        };
        if !is_identity_method(last_segment) {
            break;
        }
        let keep = head.len();
        peeled.truncate(keep);
    }
    peeled
}
/// Returns `true` when `base` ends in `<CollectionType>::<ctor>`, e.g.
/// `HashMap::new`, `std::collections::BTreeMap::default` or
/// `Vec::with_capacity`.
///
/// The type match is a suffix match on the path before the final `::`, so any
/// module prefix is accepted, as is any type whose name ends in one of the listed
/// names (e.g. `MyVec`).
fn is_rust_local_collection_constructor(base: &str) -> bool {
    const TYPES: &[&str] = &[
        "HashMap",
        "HashSet",
        "BTreeMap",
        "BTreeSet",
        "VecDeque",
        "BinaryHeap",
        "LinkedList",
        "Vec",
        "IndexMap",
        "IndexSet",
        "SmallVec",
        "FxHashMap",
        "FxHashSet",
        "DashMap",
        "DashSet",
        "RoaringBitmap",
        "RoaringTreemap",
    ];
    const VERBS: &[&str] = &[
        "new",
        "with_capacity",
        "with_capacity_and_hasher",
        "with_hasher",
        "from",
        "from_iter",
        "new_in",
        "default",
    ];
    // Split off the final `::verb` once and compare slices, instead of
    // formatting a `Type::verb` candidate string for every (type, verb) pair.
    // This is equivalent because no verb or type name contains `:`.
    let Some((type_path, verb)) = base.rsplit_once("::") else {
        return false;
    };
    VERBS.contains(&verb) && TYPES.iter().any(|ty| type_path.ends_with(*ty))
}
/// Returns `true` when `base` (a callee already passed through
/// [`peel_identity_suffix`]) looks like a call that starts building an HTTP
/// request:
///
/// * `surf::<verb>` for `post`, `get`, `put`, `delete`, `patch`, `head`, `connect`, `trace`;
/// * `ureq::<verb>` for `post`, `get`, `put`, `delete`, `patch`, `head`;
/// * `…Request::builder` (e.g. `hyper::Request::builder`);
/// * `…Client::<verb>` or `…Client::new.<verb>` for `post`, `get`, `put`,
///   `delete`, `patch`, `head`, `request`.
///
/// All checks are path-suffix matches, so any module prefix is accepted.
fn is_rust_request_builder_constructor(base: &str) -> bool {
    const SURF_VERBS: &[&str] = &[
        "post", "get", "put", "delete", "patch", "head", "connect", "trace",
    ];
    const UREQ_VERBS: &[&str] = &["post", "get", "put", "delete", "patch", "head"];
    const REQWEST_CLIENT_VERBS: &[&str] =
        &["post", "get", "put", "delete", "patch", "head", "request"];

    // Split once and compare slices rather than formatting a candidate string
    // per verb. This is equivalent because no verb contains `:` or `.`.
    if let Some((path, verb)) = base.rsplit_once("::") {
        if path.ends_with("surf") && SURF_VERBS.contains(&verb) {
            return true;
        }
        if path.ends_with("ureq") && UREQ_VERBS.contains(&verb) {
            return true;
        }
        // Covers `hyper::Request::builder`, `http::Request::builder`, ...
        if path.ends_with("Request") && verb == "builder" {
            return true;
        }
        if path.ends_with("Client") && REQWEST_CLIENT_VERBS.contains(&verb) {
            return true;
        }
    }
    // Method call on a freshly built client: `Client::new().post(...)`
    // normalises to `Client::new.post`.
    if let Some((path, verb)) = base.rsplit_once('.') {
        if path.ends_with("Client::new") && REQWEST_CLIENT_VERBS.contains(&verb) {
            return true;
        }
    }
    false
}
/// Returns `true` when the last `.`/`:`-separated segment of `callee` is one of
/// `unwrap`, `expect`, `clone`, `to_owned`, `into`, `as_ref`, `as_mut`, `ok`,
/// `await`. The type analysis treats these as passing their receiver's type
/// through unchanged.
pub fn is_identity_method(callee: &str) -> bool {
    const IDENTITY_METHODS: [&str; 9] = [
        "unwrap", "expect", "clone", "to_owned", "into", "as_ref", "as_mut", "ok", "await",
    ];
    let last_segment = match callee.rfind(['.', ':']) {
        // Both separators are single-byte ASCII, so `sep + 1` is a char boundary.
        Some(sep) => &callee[sep + 1..],
        None => callee,
    };
    IDENTITY_METHODS.contains(&last_segment)
}
/// Returns `true` when a call to `callee` is assumed to produce a number
/// (`parseInt`, `int`, `Atoi`, `intval`, `to_i`, a receiver's `.parse()`, …).
///
/// Trailing identity methods (`.unwrap()`, `.await`, …) are peeled first via
/// [`peel_identity_suffix`], and only the last `.`/`:`-separated segment is
/// compared.
///
/// The generic name `parse` is rejected when it is qualified by a well-known
/// structured-data parser (`JSON.parse`, `URI.parse`, `Url::parse`,
/// `querystring.parse`, `url.parse`, `ET.parse`, `ast.parse`, …). Those return
/// objects or strings that are often attacker-controlled. Typing them `Int`
/// would let [`is_type_safe_for_sink`] suppress real injection findings.
pub fn is_int_producing_callee(callee: &str) -> bool {
    // Qualifiers whose `parse` returns structured, non-numeric data.
    const NON_NUMERIC_PARSE_QUALIFIERS: &[&str] = &[
        "JSON",
        "URI",
        "Url",
        "url",
        "querystring",
        "qs",
        "ElementTree",
        "ET",
        "etree",
        "minidom",
        "ast",
    ];
    let base = peel_identity_suffix(callee);
    let mut segments = base.rsplit(['.', ':']);
    let suffix = segments.next().unwrap_or(&base);
    if suffix == "parse" {
        // `Url::parse` splits into "parse", "", "Url": skip the empty piece left by `::`.
        let qualifier = segments.find(|segment| !segment.is_empty());
        if qualifier.is_some_and(|q| NON_NUMERIC_PARSE_QUALIFIERS.contains(&q)) {
            return false;
        }
    }
    matches!(
        suffix,
        "parseInt"
            | "parseFloat"
            | "Number"
            | "int"
            | "float"
            | "ord"
            | "parseLong"
            | "parseDouble"
            | "parseShort"
            | "Atoi"
            | "ParseInt"
            | "ParseFloat"
            | "intval"
            | "floatval"
            | "to_i"
            | "to_f"
            | "parse"
    )
}
/// How a recognised input-validator call reports its verdict, as classified by
/// [`classify_input_validator_callee`] from the callee's name alone.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum InputValidatorPolarity {
    /// Names starting with `isvalid`/`is_valid`/`issafe`/`is_safe`/`hasvalid`/
    /// `has_valid` (case-insensitive). Presumably a `true` result means the
    /// input is valid.
    BooleanTrueIsValid,
    /// Names starting with `validate`/`verify` (case-insensitive). Presumably
    /// these signal failure via an error/exception rather than a boolean — confirm.
    ErrorReturning,
}
/// Classifies `callee` as an input validator by the (case-insensitive) prefix of
/// its last `.`/`:`-separated segment, after identity methods are peeled off.
///
/// Boolean-style prefixes (`is_valid`, `is_safe`, `has_valid` and their
/// underscore-free forms) take precedence over `validate`/`verify`. Returns
/// `None` when no prefix matches.
pub fn classify_input_validator_callee(callee: &str) -> Option<InputValidatorPolarity> {
    const BOOLEAN_PREFIXES: &[&str] = &[
        "isvalid",
        "is_valid",
        "issafe",
        "is_safe",
        "hasvalid",
        "has_valid",
    ];
    const ERROR_PREFIXES: &[&str] = &["validate", "verify"];

    let peeled = peel_identity_suffix(callee);
    let name = peeled
        .rsplit(['.', ':'])
        .next()
        .unwrap_or(&peeled)
        .to_ascii_lowercase();
    let has_prefix = |prefixes: &[&str]| prefixes.iter().any(|p| name.starts_with(*p));

    if has_prefix(BOOLEAN_PREFIXES) {
        Some(InputValidatorPolarity::BooleanTrueIsValid)
    } else if has_prefix(ERROR_PREFIXES) {
        Some(InputValidatorPolarity::ErrorReturning)
    } else {
        None
    }
}
pub fn analyze_types(
body: &SsaBody,
cfg: &Cfg,
consts: &HashMap<SsaValue, ConstLattice>,
lang: Option<Lang>,
) -> TypeFactResult {
analyze_types_with_param_types(body, cfg, consts, lang, &[])
}
pub fn analyze_types_with_param_types(
body: &SsaBody,
cfg: &Cfg,
consts: &HashMap<SsaValue, ConstLattice>,
lang: Option<Lang>,
param_types: &[Option<TypeKind>],
) -> TypeFactResult {
let mut facts: HashMap<SsaValue, TypeFact> = HashMap::new();
for block in &body.blocks {
for inst in block.phis.iter().chain(block.body.iter()) {
if cfg
.node_weight(inst.cfg_node)
.is_some_and(|ni| ni.is_numeric_length_access)
{
facts.insert(inst.value, TypeFact::from_kind(TypeKind::Int));
continue;
}
let fact = match &inst.op {
SsaOp::Const(_) => {
match consts.get(&inst.value) {
Some(ConstLattice::Str(_)) => TypeFact::from_kind(TypeKind::String),
Some(ConstLattice::Int(_)) => TypeFact::from_kind(TypeKind::Int),
Some(ConstLattice::Bool(_)) => TypeFact::from_kind(TypeKind::Bool),
Some(ConstLattice::Null) => TypeFact::from_kind(TypeKind::Null),
_ => TypeFact::unknown(),
}
}
SsaOp::Source => TypeFact::from_kind(TypeKind::String),
SsaOp::Param { index } => {
match param_types.get(*index).and_then(|t| t.clone()) {
Some(tk) => TypeFact::from_kind(tk),
None => TypeFact::unknown(),
}
}
SsaOp::SelfParam => TypeFact::from_kind(TypeKind::Object),
SsaOp::CatchParam => TypeFact::from_kind(TypeKind::Object),
SsaOp::Call { callee, .. } => {
if let Some(ty) = lang.and_then(|l| constructor_type(l, callee)) {
TypeFact::from_kind(ty)
} else if is_int_producing_callee(callee) {
TypeFact::from_kind(TypeKind::Int)
} else {
TypeFact::unknown()
}
}
SsaOp::Nop => TypeFact::unknown(),
SsaOp::Assign(uses) if uses.len() == 1 => {
TypeFact::unknown()
}
SsaOp::Assign(_uses) => {
let bin_op = cfg.node_weight(inst.cfg_node).and_then(|ni| ni.bin_op);
match bin_op {
Some(
BinOp::Sub
| BinOp::Mul
| BinOp::Div
| BinOp::Mod
| BinOp::BitAnd
| BinOp::BitOr
| BinOp::BitXor
| BinOp::LeftShift
| BinOp::RightShift
| BinOp::Eq
| BinOp::NotEq
| BinOp::Lt
| BinOp::LtEq
| BinOp::Gt
| BinOp::GtEq,
) => TypeFact::from_kind(TypeKind::Int),
_ => TypeFact::unknown(),
}
}
SsaOp::Phi(_) => {
TypeFact::unknown()
}
SsaOp::FieldProj { projected_type, .. } => match projected_type {
Some(tk) => TypeFact::from_kind(tk.clone()),
None => TypeFact::unknown(),
},
SsaOp::Undef => TypeFact::unknown(),
};
facts.insert(inst.value, fact);
}
}
for _ in 0..10 {
let mut changed = false;
for block in &body.blocks {
for inst in &block.body {
if let SsaOp::Call {
callee,
receiver: Some(recv),
..
} = &inst.op
{
if !is_identity_method(callee) {
continue;
}
if cfg
.node_weight(inst.cfg_node)
.is_some_and(|ni| ni.is_numeric_length_access)
{
continue;
}
let current_kind = facts
.get(&inst.value)
.map(|f| f.kind.clone())
.unwrap_or(TypeKind::Unknown);
if !matches!(current_kind, TypeKind::Unknown) {
continue;
}
let recv_fact = facts.get(recv).cloned().unwrap_or_else(TypeFact::unknown);
if matches!(recv_fact.kind, TypeKind::Unknown) {
continue;
}
if facts.get(&inst.value) != Some(&recv_fact) {
facts.insert(inst.value, recv_fact);
changed = true;
}
}
}
for inst in &block.body {
let SsaOp::FieldProj {
receiver,
field,
projected_type,
} = &inst.op
else {
continue;
};
if projected_type.is_some() {
continue;
}
let Some(recv_fact) = facts.get(receiver).cloned() else {
continue;
};
let field_name = body.field_name(*field).to_string();
let Some(new_fact) = TypeFact::from_dto_field(&recv_fact.kind, &field_name) else {
continue;
};
if facts.get(&inst.value) != Some(&new_fact) {
facts.insert(inst.value, new_fact);
changed = true;
}
}
for inst in &block.phis {
if let SsaOp::Phi(operands) = &inst.op {
let mut result: Option<TypeFact> = None;
for (_, val) in operands {
let operand_fact =
facts.get(val).cloned().unwrap_or_else(TypeFact::unknown);
result = Some(match result {
None => operand_fact,
Some(acc) => acc.meet(&operand_fact),
});
}
if let Some(new_fact) = result {
let old = facts.get(&inst.value);
if old != Some(&new_fact) {
facts.insert(inst.value, new_fact);
changed = true;
}
}
}
}
for inst in &block.body {
if cfg
.node_weight(inst.cfg_node)
.is_some_and(|ni| ni.is_numeric_length_access)
{
continue;
}
if let SsaOp::Assign(uses) = &inst.op {
if uses.len() == 1 {
let dto_field_fact = cfg
.node_weight(inst.cfg_node)
.and_then(|ni| ni.member_field.as_deref())
.and_then(|field| {
let recv_kind = facts.get(&uses[0])?.kind.clone();
TypeFact::from_dto_field(&recv_kind, field)
});
let new_fact = match dto_field_fact {
Some(f) => f,
None => facts
.get(&uses[0])
.cloned()
.unwrap_or_else(TypeFact::unknown),
};
let old = facts.get(&inst.value);
if old != Some(&new_fact) {
facts.insert(inst.value, new_fact);
changed = true;
}
} else if uses.len() == 2 {
let lhs = facts
.get(&uses[0])
.cloned()
.unwrap_or_else(TypeFact::unknown);
let rhs = facts
.get(&uses[1])
.cloned()
.unwrap_or_else(TypeFact::unknown);
if matches!(lhs.kind, TypeKind::Int) && matches!(rhs.kind, TypeKind::Int) {
let new_fact = TypeFact::from_kind(TypeKind::Int);
if facts.get(&inst.value) != Some(&new_fact) {
facts.insert(inst.value, new_fact);
changed = true;
}
}
}
}
}
}
if !changed {
break;
}
}
TypeFactResult { facts }
}
/// Namespace for lookups against the small, hard-coded Java class hierarchy in
/// `JAVA_HIERARCHY`.
pub struct TypeHierarchy;
/// Direct supertypes for a handful of Java classes, as `(class, supertypes)` pairs.
///
/// Only the listed edges are known. The relation is not closed transitively,
/// so multi-level ancestors must be listed explicitly (as for
/// `HttpServletRequestWrapper`). Some entries (e.g. `DatabaseConnection`,
/// `HttpResponse`) appear to be analysis kind names rather than real Java
/// types, presumably so `TypeHierarchy::resolve_kind` can map them — confirm.
static JAVA_HIERARCHY: &[(&str, &[&str])] = &[
    ("HttpServletResponse", &["ServletResponse"]),
    ("HttpServletRequest", &["ServletRequest"]),
    ("HttpURLConnection", &["URLConnection"]),
    ("CloseableHttpClient", &["HttpClient"]),
    ("FileInputStream", &["InputStream"]),
    ("FileOutputStream", &["OutputStream"]),
    ("BufferedReader", &["Reader"]),
    ("BufferedWriter", &["Writer"]),
    ("PreparedStatement", &["Statement"]),
    ("ArrayList", &["List", "Collection"]),
    ("HashMap", &["Map"]),
    ("StringBuilder", &["CharSequence"]),
    ("StringBuffer", &["CharSequence"]),
    ("OkHttpClient", &["HttpClient"]),
    ("WebClient", &["HttpClient"]),
    ("RestTemplate", &["HttpClient"]),
    ("MongoClient", &["DatabaseConnection"]),
    ("RedisTemplate", &["DatabaseConnection"]),
    ("JmsTemplate", &["DatabaseConnection"]),
    ("ResponseEntity", &["HttpResponse"]),
    (
        "HttpServletRequestWrapper",
        &["HttpServletRequest", "ServletRequest"],
    ),
    ("PrintWriter", &["Writer"]),
    ("FileReader", &["Reader"]),
    ("FileWriter", &["Writer"]),
    ("InputStreamReader", &["Reader"]),
    ("OutputStreamWriter", &["Writer"]),
];
impl TypeHierarchy {
    /// Returns `true` when `sub` equals `super_type` or `JAVA_HIERARCHY`
    /// lists `super_type` as a direct supertype of `sub`.
    pub fn is_subtype_of(sub: &str, super_type: &str) -> bool {
        sub == super_type
            || JAVA_HIERARCHY
                .iter()
                .any(|&(class, supers)| class == sub && supers.contains(&super_type))
    }

    /// Maps `class_name` to a [`TypeKind`].
    ///
    /// The class name itself is tried first. Failing that, the listed direct
    /// supertypes are tried in table order, and the first one that maps
    /// produces the kind.
    pub fn resolve_kind(class_name: &str) -> Option<TypeKind> {
        if let Some(direct) = crate::constraint::solver::class_name_to_type_kind(class_name) {
            return Some(direct);
        }
        JAVA_HIERARCHY
            .iter()
            .filter(|(class, _)| *class == class_name)
            .flat_map(|(_, supers)| supers.iter())
            .find_map(|s| crate::constraint::solver::class_name_to_type_kind(s))
    }
}
/// Namespace for heuristic checks of whether a [`TypeKind`] can satisfy a named
/// Go interface (e.g. `http.ResponseWriter`, `io.Writer`, `sql.DB`).
pub struct GoInterfaceTable;
impl GoInterfaceTable {
    /// Returns whether a value of `kind` may satisfy the Go `interface`.
    ///
    /// Only `http.ResponseWriter`, the `io` reader/writer/closer family and the
    /// `database/sql` handles are modelled, with or without their package
    /// qualifier. Any other interface name is optimistically reported as
    /// satisfied.
    pub fn satisfies(kind: &TypeKind, interface: &str) -> bool {
        let is_response = matches!(kind, TypeKind::HttpResponse);
        let is_file = matches!(kind, TypeKind::FileHandle);
        match interface {
            "http.ResponseWriter" | "ResponseWriter" => is_response,
            "io.Reader" | "Reader" => is_file,
            "io.Writer" | "Writer" | "io.ReadCloser" | "ReadCloser" | "io.WriteCloser"
            | "WriteCloser" | "io.ReadWriteCloser" | "ReadWriteCloser" => is_response || is_file,
            "sql.DB" | "sql.Conn" | "sql.Tx" | "DB" => {
                matches!(kind, TypeKind::DatabaseConnection)
            }
            _ => true,
        }
    }

    /// Returns `true` only when `kind` certainly cannot satisfy `interface`.
    ///
    /// Scalars (`Int`, `Bool`, `String`) are always excluded from the modelled
    /// interfaces, along with the resource kinds listed per interface. Unmodelled
    /// interfaces (including `io.Writer` and `io.Reader`) and `Unknown`-like kinds
    /// yield `false`.
    pub fn definitely_not(kind: &TypeKind, interface: &str) -> bool {
        let is_scalar = matches!(kind, TypeKind::Int | TypeKind::Bool | TypeKind::String);
        match interface {
            "http.ResponseWriter" | "ResponseWriter" => {
                is_scalar
                    || matches!(
                        kind,
                        TypeKind::FileHandle
                            | TypeKind::DatabaseConnection
                            | TypeKind::Url
                            | TypeKind::HttpClient
                    )
            }
            "io.ReadCloser" | "ReadCloser" => {
                is_scalar
                    || matches!(
                        kind,
                        TypeKind::DatabaseConnection | TypeKind::Url | TypeKind::HttpClient
                    )
            }
            "sql.DB" | "sql.Conn" | "sql.Tx" | "DB" => {
                is_scalar
                    || matches!(
                        kind,
                        TypeKind::HttpResponse
                            | TypeKind::FileHandle
                            | TypeKind::HttpClient
                            | TypeKind::Url
                    )
            }
            "io.WriteCloser" | "WriteCloser" | "io.ReadWriteCloser" | "ReadWriteCloser" => {
                is_scalar || matches!(kind, TypeKind::DatabaseConnection | TypeKind::Url)
            }
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use petgraph::Graph;
use petgraph::graph::NodeIndex;
use smallvec::SmallVec;
#[test]
fn const_types_inferred() {
let n0 = NodeIndex::new(0);
let n1 = NodeIndex::new(1);
let n2 = NodeIndex::new(2);
let body = SsaBody {
blocks: vec![SsaBlock {
id: BlockId(0),
phis: vec![],
body: vec![
SsaInst {
value: SsaValue(0),
op: SsaOp::Const(Some("42".into())),
cfg_node: n0,
var_name: Some("x".into()),
span: (0, 2),
},
SsaInst {
value: SsaValue(1),
op: SsaOp::Const(Some("\"hello\"".into())),
cfg_node: n1,
var_name: Some("y".into()),
span: (3, 10),
},
SsaInst {
value: SsaValue(2),
op: SsaOp::Source,
cfg_node: n2,
var_name: Some("z".into()),
span: (11, 15),
},
],
terminator: Terminator::Return(None),
preds: SmallVec::new(),
succs: SmallVec::new(),
}],
entry: BlockId(0),
value_defs: vec![
ValueDef {
var_name: Some("x".into()),
cfg_node: n0,
block: BlockId(0),
},
ValueDef {
var_name: Some("y".into()),
cfg_node: n1,
block: BlockId(0),
},
ValueDef {
var_name: Some("z".into()),
cfg_node: n2,
block: BlockId(0),
},
],
cfg_node_map: [(n0, SsaValue(0)), (n1, SsaValue(1)), (n2, SsaValue(2))]
.into_iter()
.collect(),
exception_edges: vec![],
field_interner: crate::ssa::ir::FieldInterner::default(),
field_writes: std::collections::HashMap::new(),
synthetic_externals: std::collections::HashSet::new(),
};
let consts = HashMap::from([
(SsaValue(0), ConstLattice::Int(42)),
(SsaValue(1), ConstLattice::Str("hello".into())),
]);
let cfg: crate::cfg::Cfg = Graph::new();
let result = analyze_types(&body, &cfg, &consts, None);
assert!(result.is_int(SsaValue(0)));
assert_eq!(
result.facts.get(&SsaValue(1)).unwrap().kind,
TypeKind::String
);
assert_eq!(
result.facts.get(&SsaValue(2)).unwrap().kind,
TypeKind::String
); }
#[test]
fn security_type_variants_distinct() {
let http_client = TypeFact::from_kind(TypeKind::HttpClient);
let url = TypeFact::from_kind(TypeKind::Url);
let http_response = TypeFact::from_kind(TypeKind::HttpResponse);
let db_conn = TypeFact::from_kind(TypeKind::DatabaseConnection);
let file_handle = TypeFact::from_kind(TypeKind::FileHandle);
assert_eq!(http_client.meet(&http_client).kind, TypeKind::HttpClient);
assert_eq!(url.meet(&url).kind, TypeKind::Url);
assert_eq!(http_client.meet(&url).kind, TypeKind::Unknown);
assert_eq!(http_response.meet(&db_conn).kind, TypeKind::Unknown);
assert_eq!(file_handle.meet(&http_client).kind, TypeKind::Unknown);
}
#[test]
fn label_prefix_mappings() {
assert_eq!(TypeKind::HttpClient.label_prefix(), Some("HttpClient"));
assert_eq!(TypeKind::HttpResponse.label_prefix(), Some("HttpResponse"));
assert_eq!(TypeKind::Url.label_prefix(), Some("URL"));
assert_eq!(
TypeKind::DatabaseConnection.label_prefix(),
Some("DatabaseConnection")
);
assert_eq!(TypeKind::FileHandle.label_prefix(), Some("FileHandle"));
assert_eq!(TypeKind::String.label_prefix(), None);
assert_eq!(TypeKind::Int.label_prefix(), None);
assert_eq!(TypeKind::Unknown.label_prefix(), None);
}
#[test]
fn constructor_type_inference() {
let n0 = NodeIndex::new(0);
let n1 = NodeIndex::new(1);
let body = SsaBody {
blocks: vec![SsaBlock {
id: BlockId(0),
phis: vec![],
body: vec![
SsaInst {
value: SsaValue(0),
op: SsaOp::Call {
callee: "URL".into(),
callee_text: None,
args: vec![],
receiver: None,
},
cfg_node: n0,
var_name: Some("url".into()),
span: (0, 5),
},
SsaInst {
value: SsaValue(1),
op: SsaOp::Call {
callee: "HttpClient.newHttpClient".into(),
callee_text: None,
args: vec![],
receiver: None,
},
cfg_node: n1,
var_name: Some("client".into()),
span: (6, 20),
},
],
terminator: Terminator::Return(None),
preds: SmallVec::new(),
succs: SmallVec::new(),
}],
entry: BlockId(0),
value_defs: vec![
ValueDef {
var_name: Some("url".into()),
cfg_node: n0,
block: BlockId(0),
},
ValueDef {
var_name: Some("client".into()),
cfg_node: n1,
block: BlockId(0),
},
],
cfg_node_map: [(n0, SsaValue(0)), (n1, SsaValue(1))].into_iter().collect(),
exception_edges: vec![],
field_interner: crate::ssa::ir::FieldInterner::default(),
field_writes: std::collections::HashMap::new(),
synthetic_externals: std::collections::HashSet::new(),
};
let consts = HashMap::new();
let cfg: crate::cfg::Cfg = Graph::new();
let result = analyze_types(&body, &cfg, &consts, Some(Lang::Java));
assert_eq!(result.get_type(SsaValue(0)), Some(&TypeKind::Url));
assert_eq!(result.get_type(SsaValue(1)), Some(&TypeKind::HttpClient));
let result_js = analyze_types(&body, &cfg, &consts, Some(Lang::JavaScript));
assert_eq!(result_js.get_type(SsaValue(0)), Some(&TypeKind::Url));
assert_eq!(result_js.get_type(SsaValue(1)), Some(&TypeKind::Unknown));
}
#[test]
fn get_type_and_is_type() {
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(TypeKind::HttpClient));
facts.insert(SsaValue(1), TypeFact::from_kind(TypeKind::Int));
let result = TypeFactResult { facts };
assert_eq!(result.get_type(SsaValue(0)), Some(&TypeKind::HttpClient));
assert!(result.is_type(SsaValue(0), &TypeKind::HttpClient));
assert!(!result.is_type(SsaValue(0), &TypeKind::Url));
assert!(result.is_int(SsaValue(1)));
assert_eq!(result.get_type(SsaValue(99)), None);
}
#[test]
fn int_suppresses_every_type_suppressible_cap() {
use crate::labels::Cap;
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(TypeKind::Int));
let result = TypeFactResult { facts };
for cap in [
Cap::SQL_QUERY,
Cap::FILE_IO,
Cap::SHELL_ESCAPE,
Cap::HTML_ESCAPE,
Cap::SSRF,
Cap::DATA_EXFIL,
] {
assert!(
is_type_safe_for_sink(&[SsaValue(0)], cap, &result),
"Int must suppress {cap:?}",
);
}
assert!(!is_type_safe_for_sink(
&[SsaValue(0)],
Cap::CODE_EXEC,
&result
));
assert!(!is_type_safe_for_sink(
&[SsaValue(0)],
Cap::DESERIALIZE,
&result
));
}
#[test]
fn bool_suppresses_every_type_suppressible_cap() {
use crate::labels::Cap;
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(TypeKind::Bool));
let result = TypeFactResult { facts };
for cap in [
Cap::SQL_QUERY,
Cap::FILE_IO,
Cap::SHELL_ESCAPE,
Cap::HTML_ESCAPE,
Cap::SSRF,
Cap::DATA_EXFIL,
] {
assert!(
is_type_safe_for_sink(&[SsaValue(0)], cap, &result),
"Bool must suppress {cap:?}",
);
}
}
#[test]
fn string_does_not_trigger_sink_suppression() {
use crate::labels::Cap;
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(TypeKind::String));
let result = TypeFactResult { facts };
assert!(!is_type_safe_for_sink(
&[SsaValue(0)],
Cap::SQL_QUERY,
&result
));
assert!(!is_type_safe_for_sink(&[SsaValue(0)], Cap::SSRF, &result));
assert!(!is_type_safe_for_sink(
&[SsaValue(0)],
Cap::SHELL_ESCAPE,
&result
));
}
#[test]
fn type_kind_cap_suppression_matrix() {
use crate::labels::Cap;
let caps = [
("SQL_QUERY", Cap::SQL_QUERY),
("FILE_IO", Cap::FILE_IO),
("SHELL_ESCAPE", Cap::SHELL_ESCAPE),
("HTML_ESCAPE", Cap::HTML_ESCAPE),
("SSRF", Cap::SSRF),
("DATA_EXFIL", Cap::DATA_EXFIL),
("CODE_EXEC", Cap::CODE_EXEC),
("DESERIALIZE", Cap::DESERIALIZE),
];
let rows: &[(&str, TypeKind, [bool; 8])] = &[
(
"Int",
TypeKind::Int,
[true, true, true, true, true, true, false, false],
),
(
"Bool",
TypeKind::Bool,
[true, true, true, true, true, true, false, false],
),
(
"String",
TypeKind::String,
[false, false, false, false, false, false, false, false],
),
(
"Url",
TypeKind::Url,
[false, false, false, false, false, false, false, false],
),
(
"Object",
TypeKind::Object,
[false, false, false, false, false, false, false, false],
),
(
"Unknown",
TypeKind::Unknown,
[false, false, false, false, false, false, false, false],
),
];
for (kind_name, kind, expected) in rows {
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(kind.clone()));
let result = TypeFactResult { facts };
for (i, (cap_name, cap)) in caps.iter().enumerate() {
let got = is_type_safe_for_sink(&[SsaValue(0)], *cap, &result);
assert_eq!(
got, expected[i],
"matrix mismatch for ({kind_name}, {cap_name}): expected {}, got {got}",
expected[i]
);
}
}
}
#[test]
fn empty_values_never_suppress() {
use crate::labels::Cap;
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(TypeKind::Int));
let result = TypeFactResult { facts };
for cap in [
Cap::SQL_QUERY,
Cap::FILE_IO,
Cap::SHELL_ESCAPE,
Cap::HTML_ESCAPE,
Cap::SSRF,
Cap::DATA_EXFIL,
Cap::CODE_EXEC,
Cap::DESERIALIZE,
] {
assert!(
!is_type_safe_for_sink(&[], cap, &result),
"empty values must never suppress {cap:?}",
);
}
}
#[test]
fn caps_without_type_suppressible_bits_never_fire() {
use crate::labels::Cap;
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(TypeKind::Int));
let result = TypeFactResult { facts };
for cap in [
Cap::CODE_EXEC,
Cap::DESERIALIZE,
Cap::CRYPTO,
Cap::URL_ENCODE,
] {
assert!(
!is_type_safe_for_sink(&[SsaValue(0)], cap, &result),
"Int must NOT suppress non-type-suppressible {cap:?}",
);
}
}
#[test]
fn mixed_type_operands_do_not_suppress() {
use crate::labels::Cap;
let mut facts = HashMap::new();
facts.insert(SsaValue(0), TypeFact::from_kind(TypeKind::Int));
facts.insert(SsaValue(1), TypeFact::from_kind(TypeKind::String));
let result = TypeFactResult { facts };
assert!(!is_type_safe_for_sink(
&[SsaValue(0), SsaValue(1)],
Cap::SQL_QUERY,
&result
));
}
#[test]
fn param_types_seed_param_value_facts() {
use crate::cfg::Cfg;
let n0 = NodeIndex::new(0);
let n1 = NodeIndex::new(1);
let body = SsaBody {
blocks: vec![SsaBlock {
id: BlockId(0),
phis: vec![],
body: vec![
SsaInst {
value: SsaValue(0),
op: SsaOp::Param { index: 0 },
cfg_node: n0,
var_name: Some("user_id".into()),
span: (0, 7),
},
SsaInst {
value: SsaValue(1),
op: SsaOp::Param { index: 99 },
cfg_node: n1,
var_name: Some("oob".into()),
span: (8, 11),
},
],
terminator: Terminator::Return(None),
preds: SmallVec::new(),
succs: SmallVec::new(),
}],
entry: BlockId(0),
value_defs: vec![
ValueDef {
var_name: Some("user_id".into()),
cfg_node: n0,
block: BlockId(0),
},
ValueDef {
var_name: Some("oob".into()),
cfg_node: n1,
block: BlockId(0),
},
],
cfg_node_map: [(n0, SsaValue(0)), (n1, SsaValue(1))].into_iter().collect(),
exception_edges: vec![],
field_interner: crate::ssa::ir::FieldInterner::default(),
field_writes: std::collections::HashMap::new(),
synthetic_externals: std::collections::HashSet::new(),
};
let consts = HashMap::new();
let cfg: Cfg = petgraph::Graph::new();
let param_types = vec![Some(TypeKind::Int)];
let result =
analyze_types_with_param_types(&body, &cfg, &consts, Some(Lang::Java), ¶m_types);
assert_eq!(result.get_type(SsaValue(0)), Some(&TypeKind::Int));
assert_eq!(result.get_type(SsaValue(1)), Some(&TypeKind::Unknown));
let result2 = analyze_types(&body, &cfg, &consts, Some(Lang::Java));
assert_eq!(result2.get_type(SsaValue(0)), Some(&TypeKind::Unknown));
}
#[test]
fn hierarchy_http_servlet_response_is_servlet_response() {
assert!(TypeHierarchy::is_subtype_of(
"HttpServletResponse",
"ServletResponse"
));
}
#[test]
fn hierarchy_string_is_not_servlet_response() {
assert!(!TypeHierarchy::is_subtype_of("String", "ServletResponse"));
}
#[test]
fn hierarchy_identity_subtype() {
assert!(TypeHierarchy::is_subtype_of(
"HttpServletResponse",
"HttpServletResponse"
));
}
#[test]
fn resolve_closeable_http_client() {
assert_eq!(
TypeHierarchy::resolve_kind("CloseableHttpClient"),
Some(TypeKind::HttpClient)
);
}
#[test]
fn resolve_string_builder() {
assert_eq!(
TypeHierarchy::resolve_kind("StringBuilder"),
Some(TypeKind::String)
);
}
#[test]
fn go_file_handle_definitely_not_response_writer() {
assert!(GoInterfaceTable::definitely_not(
&TypeKind::FileHandle,
"http.ResponseWriter"
));
}
#[test]
fn go_http_response_not_definitely_not_response_writer() {
assert!(!GoInterfaceTable::definitely_not(
&TypeKind::HttpResponse,
"http.ResponseWriter"
));
}
#[test]
fn go_http_response_satisfies_response_writer() {
assert!(GoInterfaceTable::satisfies(
&TypeKind::HttpResponse,
"http.ResponseWriter"
));
}
#[test]
fn go_file_handle_does_not_satisfy_response_writer() {
assert!(!GoInterfaceTable::satisfies(
&TypeKind::FileHandle,
"http.ResponseWriter"
));
}
#[test]
fn go_http_response_satisfies_io_writer() {
assert!(GoInterfaceTable::satisfies(
&TypeKind::HttpResponse,
"io.Writer"
));
}
#[test]
fn constructor_type_php() {
assert_eq!(
constructor_type(Lang::Php, "PDO"),
Some(TypeKind::DatabaseConnection)
);
assert_eq!(
constructor_type(Lang::Php, "mysqli"),
Some(TypeKind::DatabaseConnection)
);
assert_eq!(
constructor_type(Lang::Php, "curl_init"),
Some(TypeKind::HttpClient)
);
assert_eq!(
constructor_type(Lang::Php, "fopen"),
Some(TypeKind::FileHandle)
);
assert_eq!(
constructor_type(Lang::Php, "SplFileObject"),
Some(TypeKind::FileHandle)
);
assert_eq!(constructor_type(Lang::Php, "array_map"), None);
}
#[test]
fn constructor_type_c() {
    // (callee, expected kind) pairs for C; `printf` is a plain call.
    let cases = [
        ("fopen", Some(TypeKind::FileHandle)),
        ("curl_easy_init", Some(TypeKind::HttpClient)),
        ("mysql_real_connect", Some(TypeKind::DatabaseConnection)),
        ("PQconnectdb", Some(TypeKind::DatabaseConnection)),
        ("printf", None),
    ];
    for (callee, expected) in cases {
        assert_eq!(
            constructor_type(Lang::C, callee),
            expected,
            "C callee {callee}"
        );
    }
}
#[test]
fn constructor_type_cpp() {
    // C++ shares the C factory functions and adds the iostream file classes.
    let cases = [
        ("fopen", Some(TypeKind::FileHandle)),
        ("curl_easy_init", Some(TypeKind::HttpClient)),
        ("ifstream", Some(TypeKind::FileHandle)),
        ("ofstream", Some(TypeKind::FileHandle)),
        ("fstream", Some(TypeKind::FileHandle)),
        ("printf", None),
    ];
    for (callee, expected) in cases {
        assert_eq!(
            constructor_type(Lang::Cpp, callee),
            expected,
            "C++ callee {callee}"
        );
    }
}
#[test]
fn constructor_type_javascript_typescript_local_collections() {
    // JavaScript and TypeScript must agree on every constructor below:
    // built-in collections map to `LocalCollection`, `URL`/`XMLHttpRequest`
    // map to their resource kinds, and generic/unknown classes map to nothing.
    for lang in [Lang::JavaScript, Lang::TypeScript] {
        let expectations = [
            ("Map", Some(TypeKind::LocalCollection)),
            ("Set", Some(TypeKind::LocalCollection)),
            ("WeakMap", Some(TypeKind::LocalCollection)),
            ("WeakSet", Some(TypeKind::LocalCollection)),
            ("Array", Some(TypeKind::LocalCollection)),
            ("URL", Some(TypeKind::Url)),
            ("XMLHttpRequest", Some(TypeKind::HttpClient)),
            ("Object", None),
            ("Promise", None),
            ("Foo", None),
        ];
        for (ctor, want) in expectations {
            assert_eq!(constructor_type(lang, ctor), want, "constructor {ctor}");
        }
    }
}
#[test]
fn constructor_type_ruby() {
    // Ruby callees are matched on their full receiver-qualified text
    // (`Net::HTTP.new`, `File.open`, ...); `puts` and `Array.new` are untracked.
    let cases = [
        ("Net::HTTP.new", Some(TypeKind::HttpClient)),
        ("Net::HTTP.get", Some(TypeKind::HttpClient)),
        ("HTTParty.get", Some(TypeKind::HttpClient)),
        ("HTTParty.post", Some(TypeKind::HttpClient)),
        ("URI.parse", Some(TypeKind::Url)),
        ("PG.connect", Some(TypeKind::DatabaseConnection)),
        ("Sequel.connect", Some(TypeKind::DatabaseConnection)),
        ("Mysql2::Client.new", Some(TypeKind::DatabaseConnection)),
        ("File.open", Some(TypeKind::FileHandle)),
        ("File.new", Some(TypeKind::FileHandle)),
        ("puts", None),
        ("Array.new", None),
    ];
    for (callee, expected) in cases {
        assert_eq!(
            constructor_type(Lang::Ruby, callee),
            expected,
            "Ruby callee {callee}"
        );
    }
}
#[test]
fn constructor_type_rust_exact() {
    // Rust paths, both fully qualified and short. The last positive case
    // includes call arguments and a trailing `.unwrap`, and must still
    // resolve to a database connection.
    let cases = [
        ("reqwest::Client::new", Some(TypeKind::HttpClient)),
        ("reqwest::get", Some(TypeKind::HttpClient)),
        ("File::open", Some(TypeKind::FileHandle)),
        ("File::create", Some(TypeKind::FileHandle)),
        ("std::fs::File::open", Some(TypeKind::FileHandle)),
        ("Url::parse", Some(TypeKind::Url)),
        ("rusqlite::Connection::open", Some(TypeKind::DatabaseConnection)),
        ("diesel::PgConnection::establish", Some(TypeKind::DatabaseConnection)),
        ("diesel::SqliteConnection::establish", Some(TypeKind::DatabaseConnection)),
        ("Connection::open", Some(TypeKind::DatabaseConnection)),
        ("Connection::open(\"app.db\").unwrap", Some(TypeKind::DatabaseConnection)),
        ("println!", None),
    ];
    for (path, expected) in cases {
        assert_eq!(
            constructor_type(Lang::Rust, path),
            expected,
            "Rust callee {path}"
        );
    }
}
#[test]
fn constructor_type_java_expanded() {
    // Additional Java client classes: three HTTP clients and one database client.
    let cases = [
        ("OkHttpClient", TypeKind::HttpClient),
        ("WebClient", TypeKind::HttpClient),
        ("RestTemplate", TypeKind::HttpClient),
        ("MongoClient", TypeKind::DatabaseConnection),
    ];
    for (class, kind) in cases {
        assert_eq!(
            constructor_type(Lang::Java, class),
            Some(kind),
            "Java class {class}"
        );
    }
}
#[test]
fn constructor_type_go_url() {
    // Go's `url.Parse` produces a URL value.
    let kind = constructor_type(Lang::Go, "url.Parse");
    assert_eq!(kind, Some(TypeKind::Url));
}
#[test]
fn constructor_type_python_aiohttp() {
    // Python HTTP client factories from aiohttp, httpx and urllib3.
    for callee in ["aiohttp.ClientSession", "httpx.Client", "urllib3.PoolManager"] {
        assert_eq!(
            constructor_type(Lang::Python, callee),
            Some(TypeKind::HttpClient),
            "Python callee {callee}"
        );
    }
}
#[test]
fn java_hierarchy_expansion() {
    // Client classes must sit under their abstract HTTP/database supertypes ...
    let subtype_pairs = [
        ("OkHttpClient", "HttpClient"),
        ("WebClient", "HttpClient"),
        ("RestTemplate", "HttpClient"),
        ("MongoClient", "DatabaseConnection"),
        ("RedisTemplate", "DatabaseConnection"),
        ("JmsTemplate", "DatabaseConnection"),
    ];
    for (child, parent) in subtype_pairs {
        assert!(
            TypeHierarchy::is_subtype_of(child, parent),
            "{child} should be a subtype of {parent}"
        );
    }

    // ... and a representative subset must resolve to the matching TypeKind.
    let resolutions = [
        ("OkHttpClient", TypeKind::HttpClient),
        ("RestTemplate", TypeKind::HttpClient),
        ("MongoClient", TypeKind::DatabaseConnection),
    ];
    for (class, kind) in resolutions {
        assert_eq!(
            TypeHierarchy::resolve_kind(class),
            Some(kind),
            "resolve_kind({class})"
        );
    }
}
#[test]
fn go_interface_read_closer() {
    const IFACE: &str = "io.ReadCloser";

    // Which kinds positively satisfy `io.ReadCloser`.
    for (kind, expected) in [
        (TypeKind::FileHandle, true),
        (TypeKind::HttpResponse, true),
        (TypeKind::Int, false),
    ] {
        assert_eq!(
            GoInterfaceTable::satisfies(&kind, IFACE),
            expected,
            "satisfies({kind:?}, {IFACE})"
        );
    }

    // Which kinds can be ruled out entirely.
    for (kind, expected) in [
        (TypeKind::Int, true),
        (TypeKind::DatabaseConnection, true),
        (TypeKind::HttpClient, true),
        (TypeKind::FileHandle, false),
    ] {
        assert_eq!(
            GoInterfaceTable::definitely_not(&kind, IFACE),
            expected,
            "definitely_not({kind:?}, {IFACE})"
        );
    }
}
#[test]
fn go_http_client_definitely_not_response_writer() {
    // An outbound HTTP client is never the server-side response writer.
    let ruled_out =
        GoInterfaceTable::definitely_not(&TypeKind::HttpClient, "http.ResponseWriter");
    assert!(ruled_out);
}
#[test]
fn java_hierarchy_resolve_response_entity() {
    // Spring's `ResponseEntity` resolves to the HTTP response kind.
    let resolved = TypeHierarchy::resolve_kind("ResponseEntity");
    assert_eq!(resolved, Some(TypeKind::HttpResponse));
}
#[test]
fn java_hierarchy_resolve_print_writer() {
    // `PrintWriter` resolves to a file handle and is a `Writer` subtype.
    let resolved = TypeHierarchy::resolve_kind("PrintWriter");
    assert_eq!(resolved, Some(TypeKind::FileHandle));

    let is_writer = TypeHierarchy::is_subtype_of("PrintWriter", "Writer");
    assert!(is_writer);
}
#[test]
fn java_hierarchy_io_subtypes() {
    // java.io reader/writer classes and the servlet request wrapper, each
    // paired with a supertype it must be reported under.
    let pairs = [
        ("FileReader", "Reader"),
        ("FileWriter", "Writer"),
        ("InputStreamReader", "Reader"),
        ("OutputStreamWriter", "Writer"),
        ("HttpServletRequestWrapper", "HttpServletRequest"),
        ("HttpServletRequestWrapper", "ServletRequest"),
    ];
    for (child, parent) in pairs {
        assert!(
            TypeHierarchy::is_subtype_of(child, parent),
            "{child} should be a subtype of {parent}"
        );
    }
}
#[test]
fn go_interface_sql_db_definitely_not_response() {
    // Database handles and HTTP/file values are mutually exclusive: each
    // (kind, interface) pair below must be ruled out.
    let exclusions = [
        (TypeKind::DatabaseConnection, "http.ResponseWriter"),
        (TypeKind::HttpResponse, "sql.DB"),
        (TypeKind::FileHandle, "sql.DB"),
        (TypeKind::HttpClient, "sql.DB"),
    ];
    for (kind, iface) in exclusions {
        assert!(
            GoInterfaceTable::definitely_not(&kind, iface),
            "{kind:?} should be ruled out for {iface}"
        );
    }
}
#[test]
fn go_interface_sql_db_satisfies() {
    // A database connection satisfies every `database/sql` handle type.
    for iface in ["sql.DB", "sql.Conn", "sql.Tx"] {
        assert!(
            GoInterfaceTable::satisfies(&TypeKind::DatabaseConnection, iface),
            "DatabaseConnection should satisfy {iface}"
        );
    }

    // Unrelated kinds do not satisfy `sql.DB`.
    for kind in [TypeKind::HttpResponse, TypeKind::Int] {
        assert!(
            !GoInterfaceTable::satisfies(&kind, "sql.DB"),
            "{kind:?} must not satisfy sql.DB"
        );
    }
}
#[test]
fn go_interface_write_closer() {
    const IFACE: &str = "io.WriteCloser";

    // Positive satisfaction checks.
    for (kind, expected) in [
        (TypeKind::HttpResponse, true),
        (TypeKind::FileHandle, true),
        (TypeKind::Int, false),
        (TypeKind::DatabaseConnection, false),
    ] {
        assert_eq!(
            GoInterfaceTable::satisfies(&kind, IFACE),
            expected,
            "satisfies({kind:?}, {IFACE})"
        );
    }

    // Definite-exclusion checks.
    for (kind, expected) in [
        (TypeKind::DatabaseConnection, true),
        (TypeKind::FileHandle, false),
    ] {
        assert_eq!(
            GoInterfaceTable::definitely_not(&kind, IFACE),
            expected,
            "definitely_not({kind:?}, {IFACE})"
        );
    }
}
#[test]
fn colon_normalization_in_constructor_type() {
    // Plain (colon-free) callees in several languages must still resolve.
    // NOTE(review): despite the test name, none of these inputs contains a
    // `:`/`::`; presumably this guards that colon normalization leaves such
    // names untouched — confirm, and consider adding a colon-qualified case.
    let cases = [
        (Lang::Java, "URL", TypeKind::Url),
        (Lang::JavaScript, "URL", TypeKind::Url),
        (Lang::Python, "requests.get", TypeKind::HttpClient),
        (Lang::Go, "http.Get", TypeKind::HttpClient),
    ];
    for (lang, callee, kind) in cases {
        assert_eq!(constructor_type(lang, callee), Some(kind), "callee {callee}");
    }
}
#[test]
fn dto_field_lookup_returns_field_type_kind() {
    // Build a two-field DTO and look each declared field up by name.
    let mut fields = DtoFields::new("CreateUser");
    fields.insert("age", TypeKind::Int);
    fields.insert("email", TypeKind::String);
    let receiver = TypeKind::Dto(fields);

    let age_fact = TypeFact::from_dto_field(&receiver, "age").expect("age field present");
    assert_eq!(age_fact.kind, TypeKind::Int);

    let email_fact = TypeFact::from_dto_field(&receiver, "email").expect("email field present");
    assert_eq!(email_fact.kind, TypeKind::String);

    // A field that was never declared yields no fact.
    assert!(TypeFact::from_dto_field(&receiver, "missing").is_none());
}
#[test]
fn dto_field_lookup_on_non_dto_returns_none() {
    // Field lookup on any non-DTO receiver kind must produce nothing.
    let non_dto_kinds = [
        TypeKind::Int,
        TypeKind::String,
        TypeKind::Object,
        TypeKind::Unknown,
        TypeKind::HttpClient,
    ];
    for k in &non_dto_kinds {
        let fact = TypeFact::from_dto_field(k, "any_field");
        assert!(fact.is_none(), "non-DTO {k:?} must not produce a field fact",);
    }
}
#[test]
fn dto_field_lookup_supports_nested_dto() {
    // Inner DTO, kept aside as the expected field kind.
    let mut address = DtoFields::new("Address");
    address.insert("zip", TypeKind::String);
    let expected = TypeKind::Dto(address.clone());

    // Outer DTO embedding the inner one as a field.
    let mut user = DtoFields::new("CreateUser");
    user.insert("address", TypeKind::Dto(address));
    user.insert("age", TypeKind::Int);
    let receiver = TypeKind::Dto(user);

    // Looking up the nested field returns the whole inner DTO kind.
    let address_fact = TypeFact::from_dto_field(&receiver, "address").expect("address present");
    assert_eq!(address_fact.kind, expected);
}
#[test]
fn empty_dto_never_resolves_fields() {
    // A DTO with no declared fields cannot resolve any field name.
    let empty = DtoFields::new("EmptyDto");
    let fact = TypeFact::from_dto_field(&TypeKind::Dto(empty), "anything");
    assert!(fact.is_none());
}
#[test]
fn dto_int_field_suppresses_sql_query_via_matrix() {
    use crate::labels::Cap;

    // Derive the fact for an Int-typed DTO field.
    let mut dto = DtoFields::new("CreateUser");
    dto.insert("age", TypeKind::Int);
    let age_fact = TypeFact::from_dto_field(&TypeKind::Dto(dto), "age").unwrap();

    // Bind that fact to SSA value 0 and query the sink-safety matrix.
    let result = TypeFactResult {
        facts: HashMap::from([(SsaValue(0), age_fact)]),
    };
    let args = [SsaValue(0)];

    // The Int field counts as safe for a SQL sink but not for code execution.
    assert!(is_type_safe_for_sink(&args, Cap::SQL_QUERY, &result));
    assert!(!is_type_safe_for_sink(&args, Cap::CODE_EXEC, &result));
}
}