use crate::abstract_interp::{self, AbstractState};
use crate::cfg::BodyId;
use crate::constraint;
use crate::pointer::LocId;
use crate::ssa::heap::HeapState;
use crate::ssa::ir::{FieldId, SsaValue};
use crate::state::lattice::Lattice;
use crate::state::symbol::SymbolId;
use crate::taint::domain::{PredicateSummary, SmallBitSet, TaintOrigin, VarTaint};
use smallvec::SmallVec;
use std::cell::RefCell;
use std::collections::HashMap;
/// Default worklist cap: the value `effective_worklist_cap` returns when no
/// override has been installed.
pub(super) const WORKLIST_SAFETY_CAP: usize = 100_000;
/// Process-wide override for the worklist cap. `0` means "no override";
/// written by `set_worklist_cap_override`, read by `effective_worklist_cap`.
static WORKLIST_CAP_OVERRIDE: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
/// Counter exposed through `max_worklist_iterations` and zeroed by
/// `reset_worklist_observability`.
// NOTE(review): the writer is not in this file — presumably the worklist
// driver in a sibling module records its iteration high-water mark here.
pub(super) static MAX_WORKLIST_ITERATIONS: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
/// Counter exposed through `worklist_cap_hit_count` and zeroed by
/// `reset_worklist_observability`.
// NOTE(review): presumably bumped by the worklist driver each time the cap
// is reached — confirm against the sibling module that writes it.
pub(super) static WORKLIST_CAP_HITS: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
/// Installs a process-wide override for the worklist cap.
///
/// Storing `0` clears the override, after which `effective_worklist_cap`
/// falls back to `WORKLIST_SAFETY_CAP`.
#[doc(hidden)]
pub fn set_worklist_cap_override(cap: usize) {
    use std::sync::atomic::Ordering::Relaxed;

    WORKLIST_CAP_OVERRIDE.store(cap, Relaxed);
}
/// Returns the worklist cap currently in force: the override when one is
/// set (non-zero), otherwise `WORKLIST_SAFETY_CAP`.
pub(super) fn effective_worklist_cap() -> usize {
    match WORKLIST_CAP_OVERRIDE.load(std::sync::atomic::Ordering::Relaxed) {
        0 => WORKLIST_SAFETY_CAP,
        forced => forced,
    }
}
/// Reads the current value of the `MAX_WORKLIST_ITERATIONS` counter.
pub fn max_worklist_iterations() -> usize {
    use std::sync::atomic::Ordering::Relaxed;

    MAX_WORKLIST_ITERATIONS.load(Relaxed)
}
/// Reads the current value of the `WORKLIST_CAP_HITS` counter.
pub fn worklist_cap_hit_count() -> usize {
    use std::sync::atomic::Ordering::Relaxed;

    WORKLIST_CAP_HITS.load(Relaxed)
}
/// Zeroes both worklist observability counters
/// (`MAX_WORKLIST_ITERATIONS`, then `WORKLIST_CAP_HITS`).
pub fn reset_worklist_observability() {
    use std::sync::atomic::Ordering::Relaxed;

    MAX_WORKLIST_ITERATIONS.store(0, Relaxed);
    WORKLIST_CAP_HITS.store(0, Relaxed);
}
/// Process-wide override for the per-value origin cap. `0` means "no
/// override"; written by `set_max_origins_override`, read by
/// `effective_max_origins`.
static MAX_ORIGINS_OVERRIDE: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
/// Number of times `push_origin_bounded` was asked to add a new origin to a
/// set that was already at the cap. Read via `origins_truncation_count`,
/// zeroed via `reset_origins_observability`.
pub(super) static ORIGINS_TRUNCATION_COUNT: std::sync::atomic::AtomicUsize =
std::sync::atomic::AtomicUsize::new(0);
/// Installs a process-wide override for the origin cap.
///
/// Storing `0` clears the override, after which `effective_max_origins`
/// reads the cap from the current analysis options.
#[doc(hidden)]
pub fn set_max_origins_override(cap: usize) {
    use std::sync::atomic::Ordering::Relaxed;

    MAX_ORIGINS_OVERRIDE.store(cap, Relaxed);
}
/// Returns the origin cap currently in force: the override when one is set
/// (non-zero), otherwise `max_origins` from the current analysis options.
pub(super) fn effective_max_origins() -> usize {
    match MAX_ORIGINS_OVERRIDE.load(std::sync::atomic::Ordering::Relaxed) {
        0 => crate::utils::analysis_options::current().max_origins as usize,
        forced => forced,
    }
}
/// Reads the current value of the `ORIGINS_TRUNCATION_COUNT` counter.
pub fn origins_truncation_count() -> usize {
    use std::sync::atomic::Ordering::Relaxed;

    ORIGINS_TRUNCATION_COUNT.load(Relaxed)
}
/// Zeroes the `ORIGINS_TRUNCATION_COUNT` counter.
pub fn reset_origins_observability() {
    use std::sync::atomic::Ordering::Relaxed;

    ORIGINS_TRUNCATION_COUNT.store(0, Relaxed);
}
// Per-thread scratch state that is filled while analysis runs and drained
// by the `take_*` / `reset_*` helpers below.
thread_local! {
// Engine notes recorded on this thread via `record_engine_note`; drained by
// `take_body_engine_notes`, cleared by `reset_body_engine_notes`.
static BODY_ENGINE_NOTES: RefCell<SmallVec<[crate::engine_notes::EngineNote; 2]>> =
RefCell::new(SmallVec::new());
// Set of `(usize, usize)` spans recorded via
// `record_path_safe_suppressed_span`.
// NOTE(review): presumably (start, end) byte offsets of suppressed sinks — confirm.
static PATH_SAFE_SUPPRESSED_SPANS: RefCell<std::collections::HashSet<(usize, usize)>> =
RefCell::new(std::collections::HashSet::new());
// Set of `(usize, usize)` spans recorded via `record_all_validated_span`.
static ALL_VALIDATED_SPANS: RefCell<std::collections::HashSet<(usize, usize)>> =
RefCell::new(std::collections::HashSet::new());
}
/// Adds `note` to this thread's engine-note buffer via
/// `engine_notes::push_unique`.
pub(crate) fn record_engine_note(note: crate::engine_notes::EngineNote) {
    BODY_ENGINE_NOTES.with(|cell| {
        let mut notes = cell.borrow_mut();
        crate::engine_notes::push_unique(&mut notes, note);
    });
}
/// Discards every engine note buffered on this thread.
pub(crate) fn reset_body_engine_notes() {
    BODY_ENGINE_NOTES.with(|cell| {
        let mut notes = cell.borrow_mut();
        notes.clear();
    });
}
/// Moves this thread's buffered engine notes out to the caller, leaving the
/// buffer empty.
pub(crate) fn take_body_engine_notes() -> SmallVec<[crate::engine_notes::EngineNote; 2]> {
    BODY_ENGINE_NOTES.with(|cell| {
        let mut notes = cell.borrow_mut();
        std::mem::take(&mut *notes)
    })
}
/// Adds `span` to this thread's path-safe-suppressed span set. Recording
/// the same span twice has no further effect.
pub(crate) fn record_path_safe_suppressed_span(span: (usize, usize)) {
    PATH_SAFE_SUPPRESSED_SPANS.with(|cell| {
        let mut spans = cell.borrow_mut();
        let _newly_added = spans.insert(span);
    });
}
/// Empties this thread's path-safe-suppressed span set.
pub fn reset_path_safe_suppressed_spans() {
    PATH_SAFE_SUPPRESSED_SPANS.with(|cell| {
        let mut spans = cell.borrow_mut();
        spans.clear();
    });
}
/// Moves this thread's path-safe-suppressed span set out to the caller,
/// leaving an empty set behind.
pub fn take_path_safe_suppressed_spans() -> std::collections::HashSet<(usize, usize)> {
    PATH_SAFE_SUPPRESSED_SPANS.with(|cell| {
        let mut spans = cell.borrow_mut();
        std::mem::take(&mut *spans)
    })
}
/// Adds `span` to this thread's all-validated span set. Recording the same
/// span twice has no further effect.
pub(crate) fn record_all_validated_span(span: (usize, usize)) {
    ALL_VALIDATED_SPANS.with(|cell| {
        let mut spans = cell.borrow_mut();
        let _newly_added = spans.insert(span);
    });
}
/// Empties this thread's all-validated span set.
pub fn reset_all_validated_spans() {
    ALL_VALIDATED_SPANS.with(|cell| {
        let mut spans = cell.borrow_mut();
        spans.clear();
    });
}
/// Moves this thread's all-validated span set out to the caller, leaving an
/// empty set behind.
pub fn take_all_validated_spans() -> std::collections::HashSet<(usize, usize)> {
    ALL_VALIDATED_SPANS.with(|cell| {
        let mut spans = cell.borrow_mut();
        std::mem::take(&mut *spans)
    })
}
/// Hashable key pairing a binding name with the body it belongs to.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct BindingKey {
    /// Name of the binding.
    pub name: String,
    /// Body the binding is scoped to.
    pub body_id: BodyId,
}

impl BindingKey {
    /// Builds a key from anything convertible into a `String` plus a body id.
    pub fn new(name: impl Into<String>, body_id: BodyId) -> Self {
        let name: String = name.into();
        Self { name, body_id }
    }
}
/// Looks up the seeded taint for `key`, returning `None` when the seed map
/// has no entry for it.
pub fn seed_lookup<'a>(
    seed: &'a HashMap<BindingKey, VarTaint>,
    key: &BindingKey,
) -> Option<&'a VarTaint> {
    HashMap::get(seed, key)
}
/// Key of a field-taint cell: one field (`FieldId`) of one location
/// (`LocId`). The derived `Ord` (by `loc`, then `field`) is the order in
/// which `SsaTaintState::field_taint` is kept sorted.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct FieldTaintKey {
/// Location the field belongs to.
pub loc: LocId,
/// Field within that location.
pub field: FieldId,
}
/// Taint stored for a single field, together with its two validation flags.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FieldCell {
    /// Taint carried by the field.
    pub taint: VarTaint,
    /// "Must" validation flag: AND-ed when cells are merged.
    pub validated_must: bool,
    /// "May" validation flag: OR-ed when cells are merged.
    pub validated_may: bool,
}

impl FieldCell {
    /// Wraps `taint` in a cell with both validation flags cleared.
    pub fn unvalidated(taint: VarTaint) -> Self {
        let (validated_must, validated_may) = (false, false);
        Self {
            taint,
            validated_must,
            validated_may,
        }
    }
}
/// Per-program-point taint state over SSA values; the `Lattice` element the
/// analysis joins and compares.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SsaTaintState {
/// Taint per SSA value, kept sorted by `SsaValue` so `get`/`set`/`remove`
/// can binary-search.
pub values: SmallVec<[(SsaValue, VarTaint); 16]>,
/// "Must" validation bit set: intersected on join.
pub validated_must: SmallBitSet,
/// "May" validation bit set: unioned on join.
pub validated_may: SmallBitSet,
/// Predicate summaries per symbol; a join keeps only symbols present on
/// both sides.
pub predicates: SmallVec<[(SymbolId, PredicateSummary); 4]>,
/// Heap component, joined via `HeapState::join`.
pub heap: HeapState,
/// Path constraints; `initial()` populates this only when
/// `constraint::is_enabled()`, and a join yields `None` unless both sides
/// are `Some`.
pub path_env: Option<constraint::PathEnv>,
/// Abstract-interpretation state; `initial()` populates this only when
/// `abstract_interp::is_enabled()`, and a join yields `None` unless both
/// sides are `Some`.
pub abstract_state: Option<AbstractState>,
/// Field-level taint cells, kept sorted by `FieldTaintKey` so
/// `get_field`/`add_field` can binary-search.
pub field_taint: SmallVec<[(FieldTaintKey, FieldCell); 4]>,
}
impl SsaTaintState {
    /// Builds the empty state: no tainted values, no validation bits, no
    /// predicates, an empty heap and no field cells. The optional
    /// `path_env` / `abstract_state` components are populated only when
    /// their respective subsystems report themselves enabled.
    pub fn initial() -> Self {
        let path_env = constraint::is_enabled().then(constraint::PathEnv::empty);
        let abstract_state = abstract_interp::is_enabled().then(AbstractState::empty);
        Self {
            values: SmallVec::new(),
            validated_must: SmallBitSet::empty(),
            validated_may: SmallBitSet::empty(),
            predicates: SmallVec::new(),
            heap: HeapState::empty(),
            path_env,
            abstract_state,
            field_taint: SmallVec::new(),
        }
    }

    /// Returns the cell stored for `key`, if any.
    pub fn get_field(&self, key: FieldTaintKey) -> Option<&FieldCell> {
        let at = self
            .field_taint
            .binary_search_by(|(k, _)| k.cmp(&key))
            .ok()?;
        Some(&self.field_taint[at].1)
    }

    /// Merges taint `t` into the cell for `key`, creating the cell when it
    /// does not exist yet.
    ///
    /// Taint with empty caps is ignored. When a cell already exists, caps
    /// and `uses_summary` are OR-ed in, origins are merged under the origin
    /// cap, `validated_must` is AND-ed and `validated_may` is OR-ed.
    pub fn add_field(
        &mut self,
        key: FieldTaintKey,
        t: VarTaint,
        validated_must: bool,
        validated_may: bool,
    ) {
        if t.caps.is_empty() {
            return;
        }
        let at = match self.field_taint.binary_search_by(|(k, _)| k.cmp(&key)) {
            Ok(at) => at,
            Err(at) => {
                // First write to this key: insert at the sorted position.
                let fresh = FieldCell {
                    taint: t,
                    validated_must,
                    validated_may,
                };
                self.field_taint.insert(at, (key, fresh));
                return;
            }
        };
        let existing = &mut self.field_taint[at].1;
        existing.taint.caps |= t.caps;
        existing.taint.uses_summary |= t.uses_summary;
        existing.taint.origins = merge_origins(&existing.taint.origins, &t.origins);
        existing.validated_must &= validated_must;
        existing.validated_may |= validated_may;
    }

    /// True when any predicate summary reports a contradiction, or when the
    /// path environment (if present) is unsatisfiable.
    pub fn has_contradiction(&self) -> bool {
        let predicate_conflict = self
            .predicates
            .iter()
            .any(|(_, summary)| summary.has_contradiction());
        if predicate_conflict {
            return true;
        }
        match &self.path_env {
            Some(env) => env.is_unsat(),
            None => false,
        }
    }

    /// Returns the taint recorded for SSA value `v`, if any.
    pub fn get(&self, v: SsaValue) -> Option<&VarTaint> {
        let at = self.values.binary_search_by(|(id, _)| id.cmp(&v)).ok()?;
        Some(&self.values[at].1)
    }

    /// Records `taint` for SSA value `v`, replacing any previous entry.
    pub fn set(&mut self, v: SsaValue, taint: VarTaint) {
        match self.values.binary_search_by(|(id, _)| id.cmp(&v)) {
            Ok(at) => self.values[at].1 = taint,
            Err(at) => self.values.insert(at, (v, taint)),
        }
    }

    /// Drops the entry for SSA value `v`; a no-op when there is none.
    pub fn remove(&mut self, v: SsaValue) {
        let Ok(at) = self.values.binary_search_by(|(id, _)| id.cmp(&v)) else {
            return;
        };
        self.values.remove(at);
    }
}
impl Lattice for SsaTaintState {
    /// Bottom element: the same state `initial()` builds.
    fn bot() -> Self {
        Self::initial()
    }

    /// Component-wise join: taint maps and field cells are merged,
    /// `validated_must` is intersected, `validated_may` is unioned,
    /// predicates keep only shared keys, and the optional components
    /// survive only when both sides carry them.
    fn join(&self, other: &Self) -> Self {
        // Struct fields are evaluated in the order written here, which is
        // the same order the components are merged in.
        SsaTaintState {
            values: merge_join_ssa_vars(&self.values, &other.values),
            validated_must: self.validated_must.intersection(other.validated_must),
            validated_may: self.validated_may.union(other.validated_may),
            predicates: merge_join_ssa_predicates(&self.predicates, &other.predicates),
            heap: self.heap.join(&other.heap),
            path_env: self
                .path_env
                .as_ref()
                .zip(other.path_env.as_ref())
                .map(|(a, b)| a.join(b)),
            abstract_state: self
                .abstract_state
                .as_ref()
                .zip(other.abstract_state.as_ref())
                .map(|(a, b)| a.join(b)),
            field_taint: merge_join_field_taint(&self.field_taint, &other.field_taint),
        }
    }

    /// Partial-order test `self ≤ other`, checked component by component
    /// and short-circuiting on the first component that fails.
    fn leq(&self, other: &Self) -> bool {
        let core_ok = ssa_vars_leq(&self.values, &other.values)
            && self.validated_must.is_superset_of(other.validated_must)
            && self.validated_may.is_subset_of(other.validated_may)
            && self.heap.leq(&other.heap)
            && field_taint_leq(&self.field_taint, &other.field_taint);
        if !core_ok {
            return false;
        }

        // A missing path environment sits above any present one; between
        // two present environments only the fact counts are compared.
        let path_ok = match (&self.path_env, &other.path_env) {
            (None, Some(_)) => false,
            (Some(a), Some(b)) => a.fact_count() >= b.fact_count(),
            (Some(_), None) | (None, None) => true,
        };
        if !path_ok {
            return false;
        }

        match (&self.abstract_state, &other.abstract_state) {
            (None, Some(_)) => false,
            (Some(a), Some(b)) => a.leq(b),
            _ => true,
        }
    }
}
/// Joins two key-sorted field-taint lists into one key-sorted list.
///
/// A key present on only one side is carried over with `validated_must`
/// cleared. A key present on both sides gets OR-ed caps and `uses_summary`,
/// origins merged under the origin cap, AND-ed `validated_must` and OR-ed
/// `validated_may`.
pub(super) fn merge_join_field_taint(
    a: &[(FieldTaintKey, FieldCell)],
    b: &[(FieldTaintKey, FieldCell)],
) -> SmallVec<[(FieldTaintKey, FieldCell); 4]> {
    use std::cmp::Ordering;

    // Copy of a one-sided entry with its "must" flag dropped.
    fn one_sided(entry: &(FieldTaintKey, FieldCell)) -> (FieldTaintKey, FieldCell) {
        let (key, cell) = entry;
        let mut demoted = cell.clone();
        demoted.validated_must = false;
        (*key, demoted)
    }

    let mut joined = SmallVec::with_capacity(a.len().max(b.len()));
    let (mut ia, mut ib) = (0, 0);
    while let (Some(lhs), Some(rhs)) = (a.get(ia), b.get(ib)) {
        match lhs.0.cmp(&rhs.0) {
            Ordering::Less => {
                joined.push(one_sided(lhs));
                ia += 1;
            }
            Ordering::Greater => {
                joined.push(one_sided(rhs));
                ib += 1;
            }
            Ordering::Equal => {
                let (lc, rc) = (&lhs.1, &rhs.1);
                let taint = VarTaint {
                    caps: lc.taint.caps | rc.taint.caps,
                    origins: merge_origins(&lc.taint.origins, &rc.taint.origins),
                    uses_summary: lc.taint.uses_summary || rc.taint.uses_summary,
                };
                let cell = FieldCell {
                    taint,
                    validated_must: lc.validated_must && rc.validated_must,
                    validated_may: lc.validated_may || rc.validated_may,
                };
                joined.push((lhs.0, cell));
                ia += 1;
                ib += 1;
            }
        }
    }
    // At most one of these tails is non-empty.
    joined.extend(a[ia..].iter().map(one_sided));
    joined.extend(b[ib..].iter().map(one_sided));
    joined
}
/// Partial-order test `a ≤ b` over two key-sorted field-taint lists.
///
/// For a key missing from `b`, the `a` cell must have empty caps and a
/// cleared `validated_must`. For a shared key, `a`'s caps must be a subset
/// of `b`'s, `b.validated_must` must imply `a.validated_must`, and
/// `a.validated_may` must imply `b.validated_may`.
pub(super) fn field_taint_leq(
    a: &[(FieldTaintKey, FieldCell)],
    b: &[(FieldTaintKey, FieldCell)],
) -> bool {
    // Both lists are sorted, so one forward-only cursor over `b` suffices.
    let mut cursor = 0;
    a.iter().all(|(key, lhs)| {
        while cursor < b.len() && b[cursor].0 < *key {
            cursor += 1;
        }
        match b.get(cursor) {
            Some((rkey, rhs)) if rkey == key => {
                (lhs.taint.caps - rhs.taint.caps).bits() == 0
                    && (lhs.validated_must || !rhs.validated_must)
                    && (rhs.validated_may || !lhs.validated_may)
            }
            _ => lhs.taint.caps.is_empty() && !lhs.validated_must,
        }
    })
}
/// Joins two `SsaValue`-sorted taint lists into one sorted list.
///
/// Entries present on only one side are copied through unchanged; entries
/// present on both sides get OR-ed caps and `uses_summary`, with origins
/// merged under the origin cap.
pub(super) fn merge_join_ssa_vars(
    a: &[(SsaValue, VarTaint)],
    b: &[(SsaValue, VarTaint)],
) -> SmallVec<[(SsaValue, VarTaint); 16]> {
    use std::cmp::Ordering;

    let mut joined = SmallVec::with_capacity(a.len().max(b.len()));
    let (mut ia, mut ib) = (0, 0);
    while let (Some(lhs), Some(rhs)) = (a.get(ia), b.get(ib)) {
        match lhs.0.cmp(&rhs.0) {
            Ordering::Less => {
                joined.push(lhs.clone());
                ia += 1;
            }
            Ordering::Greater => {
                joined.push(rhs.clone());
                ib += 1;
            }
            Ordering::Equal => {
                let (lt, rt) = (&lhs.1, &rhs.1);
                let merged = VarTaint {
                    caps: lt.caps | rt.caps,
                    origins: merge_origins(&lt.origins, &rt.origins),
                    uses_summary: lt.uses_summary || rt.uses_summary,
                };
                joined.push((lhs.0, merged));
                ia += 1;
                ib += 1;
            }
        }
    }
    // At most one of these tails is non-empty.
    joined.extend(a[ia..].iter().cloned());
    joined.extend(b[ib..].iter().cloned());
    joined
}
fn origin_sort_key(o: &TaintOrigin) -> (usize, usize, u8, usize) {
let (span_start, span_end) = o.source_span.unwrap_or((0, 0));
let kind_tag: u8 = match o.source_kind {
crate::labels::SourceKind::UserInput => 0,
crate::labels::SourceKind::EnvironmentConfig => 1,
crate::labels::SourceKind::FileSystem => 2,
crate::labels::SourceKind::Database => 3,
crate::labels::SourceKind::CaughtException => 4,
crate::labels::SourceKind::Unknown => 5,
};
(span_start, span_end, kind_tag, o.node.index())
}
pub(crate) fn push_origin_bounded(
target: &mut SmallVec<[TaintOrigin; 2]>,
new: TaintOrigin,
) -> bool {
if target.iter().any(|o| o.node == new.node) {
return true;
}
let cap = effective_max_origins();
let new_key = origin_sort_key(&new);
if target.len() < cap {
let pos = target
.iter()
.position(|o| origin_sort_key(o) > new_key)
.unwrap_or(target.len());
target.insert(pos, new);
return true;
}
let worst_idx = target
.iter()
.enumerate()
.max_by_key(|(_, o)| origin_sort_key(o))
.map(|(i, _)| i)
.expect("cap ≥ MIN_MAX_ORIGINS (1) means target is non-empty");
let worst_key = origin_sort_key(&target[worst_idx]);
ORIGINS_TRUNCATION_COUNT.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
record_engine_note(crate::engine_notes::EngineNote::OriginsTruncated { dropped: 1 });
if new_key < worst_key {
target.remove(worst_idx);
let pos = target
.iter()
.position(|o| origin_sort_key(o) > new_key)
.unwrap_or(target.len());
target.insert(pos, new);
true
} else {
false
}
}
/// Builds a fresh origin list holding the origins of `a` followed by those
/// of `b`, each pushed through `push_origin_bounded` (so the result is
/// deduplicated by node and bounded by the origin cap).
pub(super) fn merge_origins(
    a: &SmallVec<[TaintOrigin; 2]>,
    b: &SmallVec<[TaintOrigin; 2]>,
) -> SmallVec<[TaintOrigin; 2]> {
    let mut merged: SmallVec<[TaintOrigin; 2]> = SmallVec::new();
    for origin in a.iter().chain(b.iter()).copied() {
        push_origin_bounded(&mut merged, origin);
    }
    merged
}
#[allow(dead_code)] fn ssa_vars_leq(a: &[(SsaValue, VarTaint)], b: &[(SsaValue, VarTaint)]) -> bool {
let (mut i, mut j) = (0, 0);
while i < a.len() {
if j >= b.len() {
return false;
}
match a[i].0.cmp(&b[j].0) {
std::cmp::Ordering::Less => return false,
std::cmp::Ordering::Greater => {
j += 1;
}
std::cmp::Ordering::Equal => {
if a[i].1.caps & b[j].1.caps != a[i].1.caps {
return false;
}
if a[i].1.uses_summary && !b[j].1.uses_summary {
return false;
}
for orig in &a[i].1.origins {
if !b[j].1.origins.iter().any(|o| o.node == orig.node) {
return false;
}
}
i += 1;
j += 1;
}
}
}
true
}
/// Joins two `SymbolId`-sorted predicate lists.
///
/// Only symbols present on both sides survive; their summaries are joined
/// and the entry is kept only when the joined summary is non-empty.
pub(super) fn merge_join_ssa_predicates(
    a: &[(SymbolId, PredicateSummary)],
    b: &[(SymbolId, PredicateSummary)],
) -> SmallVec<[(SymbolId, PredicateSummary); 4]> {
    use std::cmp::Ordering;

    let mut kept = SmallVec::new();
    let (mut ia, mut ib) = (0, 0);
    while let (Some((sym_a, sum_a)), Some((sym_b, sum_b))) = (a.get(ia), b.get(ib)) {
        match sym_a.cmp(sym_b) {
            Ordering::Less => ia += 1,
            Ordering::Greater => ib += 1,
            Ordering::Equal => {
                let joined = sum_a.join(*sum_b);
                if !joined.is_empty() {
                    kept.push((*sym_a, joined));
                }
                ia += 1;
                ib += 1;
            }
        }
    }
    kept
}
#[cfg(test)]
mod origin_cap_tests {
    use super::*;
    use crate::labels::SourceKind;
    use petgraph::graph::NodeIndex;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises the tests in this module: they all mutate the
    /// process-wide origin-cap override.
    static TEST_GUARD: Mutex<()> = Mutex::new(());

    /// RAII handle that holds `TEST_GUARD` and an installed cap override.
    ///
    /// The override is cleared in `Drop`, so it is restored even when an
    /// assertion panics part-way through a test instead of leaking into
    /// whichever test runs next.
    struct CapOverride {
        _lock: MutexGuard<'static, ()>,
    }

    impl CapOverride {
        /// Takes the module lock (recovering from poisoning left by an
        /// earlier failed test) and installs `cap` as the override.
        fn install(cap: usize) -> Self {
            let lock = TEST_GUARD.lock().unwrap_or_else(|e| e.into_inner());
            set_max_origins_override(cap);
            Self { _lock: lock }
        }
    }

    impl Drop for CapOverride {
        fn drop(&mut self) {
            // Runs before `_lock` is released, so the next test never
            // observes a stale override.
            set_max_origins_override(0);
        }
    }

    /// Builds a `UserInput` origin at `node` with a one-unit span starting
    /// at `span_start`.
    fn origin(node: usize, span_start: usize) -> TaintOrigin {
        TaintOrigin {
            node: NodeIndex::new(node),
            source_kind: SourceKind::UserInput,
            source_span: Some((span_start, span_start + 1)),
        }
    }

    #[test]
    fn push_origin_bounded_dedups_by_node() {
        let _cap = CapOverride::install(4);
        let mut target: SmallVec<[TaintOrigin; 2]> = SmallVec::new();
        assert!(push_origin_bounded(&mut target, origin(1, 10)));
        // Same node, different span: reported as present, not re-inserted.
        assert!(push_origin_bounded(&mut target, origin(1, 99)));
        assert_eq!(target.len(), 1, "duplicate node must not grow the set");
    }

    #[test]
    fn push_origin_bounded_is_order_independent() {
        let _cap = CapOverride::install(3);
        let origins = [
            origin(1, 50),
            origin(2, 10),
            origin(3, 30),
            origin(4, 70),
            origin(5, 90),
        ];
        let mut forward: SmallVec<[TaintOrigin; 2]> = SmallVec::new();
        for o in origins.iter() {
            push_origin_bounded(&mut forward, *o);
        }
        let mut reverse: SmallVec<[TaintOrigin; 2]> = SmallVec::new();
        for o in origins.iter().rev() {
            push_origin_bounded(&mut reverse, *o);
        }
        let forward_nodes: Vec<_> = forward.iter().map(|o| o.node.index()).collect();
        let reverse_nodes: Vec<_> = reverse.iter().map(|o| o.node.index()).collect();
        assert_eq!(
            forward_nodes, reverse_nodes,
            "survivor set must not depend on insertion order: forward {forward_nodes:?} \
             reverse {reverse_nodes:?}"
        );
        // The three smallest spans (10, 30, 50) survive, in span order.
        assert_eq!(forward_nodes, vec![2, 3, 1]);
    }

    #[test]
    fn push_origin_bounded_increments_truncation_counter() {
        let _cap = CapOverride::install(2);
        reset_origins_observability();
        let mut target: SmallVec<[TaintOrigin; 2]> = SmallVec::new();
        push_origin_bounded(&mut target, origin(1, 10));
        push_origin_bounded(&mut target, origin(2, 20));
        push_origin_bounded(&mut target, origin(3, 30));
        push_origin_bounded(&mut target, origin(4, 40));
        assert_eq!(
            origins_truncation_count(),
            2,
            "expected 2 truncation events (3rd and 4th push at cap=2)"
        );
        reset_origins_observability();
    }

    #[test]
    fn merge_origins_is_symmetric() {
        let _cap = CapOverride::install(3);
        let a: SmallVec<[TaintOrigin; 2]> = [origin(1, 100), origin(2, 200)].into_iter().collect();
        let b: SmallVec<[TaintOrigin; 2]> = [origin(3, 10), origin(4, 50)].into_iter().collect();
        let ab = merge_origins(&a, &b);
        let ba = merge_origins(&b, &a);
        let ab_nodes: Vec<_> = ab.iter().map(|o| o.node.index()).collect();
        let ba_nodes: Vec<_> = ba.iter().map(|o| o.node.index()).collect();
        assert_eq!(
            ab_nodes, ba_nodes,
            "merge must be commutative under truncation: ab={ab_nodes:?} ba={ba_nodes:?}"
        );
    }

    #[test]
    fn effective_cap_reads_runtime_config_when_override_zero() {
        // Start with the override cleared; the guard clears it again on exit.
        let _cap = CapOverride::install(0);
        assert_eq!(
            effective_max_origins(),
            crate::utils::analysis_options::DEFAULT_MAX_ORIGINS as usize
        );
        set_max_origins_override(7);
        assert_eq!(effective_max_origins(), 7);
    }
}
#[cfg(test)]
mod field_taint_tests {
    use super::*;
    use crate::labels::Cap;
    use crate::pointer::LocId;
    use crate::ssa::ir::FieldId;
    use crate::taint::domain::TaintOrigin;
    use smallvec::SmallVec;

    /// Builds a field key from raw location / field ids.
    fn key(loc_raw: u32, field_raw: u32) -> FieldTaintKey {
        FieldTaintKey {
            loc: LocId(loc_raw),
            field: FieldId(field_raw),
        }
    }

    /// Builds origin-less taint carrying `caps`.
    fn taint(caps: Cap) -> VarTaint {
        VarTaint {
            caps,
            origins: SmallVec::new(),
            uses_summary: false,
        }
    }

    /// Adds an unvalidated field cell.
    fn add(s: &mut SsaTaintState, k: FieldTaintKey, t: VarTaint) {
        s.add_field(k, t, false, false);
    }

    #[test]
    fn add_field_round_trips() {
        let mut s = SsaTaintState::initial();
        let k = key(1, 7);
        add(&mut s, k, taint(Cap::ENV_VAR));
        let got = s.get_field(k).expect("field cell present");
        assert!(got.taint.caps.contains(Cap::ENV_VAR));
    }

    #[test]
    fn add_field_unions_caps() {
        let mut s = SsaTaintState::initial();
        let k = key(1, 7);
        // Two different caps on the same key, so the union is observable.
        add(&mut s, k, taint(Cap::ENV_VAR));
        add(&mut s, k, taint(Cap::FILE_IO));
        let got = s.get_field(k).unwrap();
        assert!(got.taint.caps.contains(Cap::ENV_VAR));
        assert!(got.taint.caps.contains(Cap::FILE_IO));
    }

    #[test]
    fn add_field_skips_empty_caps() {
        let mut s = SsaTaintState::initial();
        let k = key(2, 3);
        add(&mut s, k, taint(Cap::empty()));
        assert!(s.get_field(k).is_none(), "empty caps must not insert");
    }

    #[test]
    fn lattice_join_unions_keys_and_caps() {
        let k1 = key(1, 7);
        let k2 = key(2, 9);
        let mut a = SsaTaintState::initial();
        let mut b = SsaTaintState::initial();
        // The shared key carries a different cap on each side, so the join
        // has to union them rather than just pick one.
        add(&mut a, k1, taint(Cap::ENV_VAR));
        add(&mut b, k1, taint(Cap::FILE_IO));
        add(&mut b, k2, taint(Cap::FILE_IO));
        let joined = a.join(&b);
        let v1 = joined.get_field(k1).unwrap();
        assert!(v1.taint.caps.contains(Cap::ENV_VAR));
        assert!(v1.taint.caps.contains(Cap::FILE_IO));
        let v2 = joined.get_field(k2).unwrap();
        assert!(v2.taint.caps.contains(Cap::FILE_IO));
    }

    #[test]
    fn lattice_leq_detects_strict_increase() {
        let mut b = SsaTaintState::initial();
        add(&mut b, key(1, 7), taint(Cap::ENV_VAR));
        let a = SsaTaintState::initial();
        assert!(a.leq(&b), "empty state ≤ state with a field cell");
        assert!(!b.leq(&a), "state with a field cell is NOT ≤ empty state");
    }

    #[test]
    fn lattice_leq_holds_when_caps_subset() {
        let k = key(3, 4);
        let mut a = SsaTaintState::initial();
        let mut b = SsaTaintState::initial();
        add(&mut a, k, taint(Cap::ENV_VAR));
        add(&mut b, k, taint(Cap::ENV_VAR | Cap::FILE_IO));
        assert!(a.leq(&b));
        assert!(!b.leq(&a));
    }

    #[test]
    fn merge_origins_via_join_dedups_by_node() {
        use petgraph::graph::NodeIndex;
        let k = key(1, 1);
        let o1 = TaintOrigin {
            node: NodeIndex::new(5),
            source_kind: crate::labels::SourceKind::UserInput,
            source_span: Some((0, 1)),
        };
        let o2 = TaintOrigin {
            node: NodeIndex::new(7),
            source_kind: crate::labels::SourceKind::EnvironmentConfig,
            source_span: Some((10, 11)),
        };
        let mut t1 = taint(Cap::ENV_VAR);
        t1.origins.push(o1);
        let mut t2 = taint(Cap::ENV_VAR);
        t2.origins.push(o1);
        t2.origins.push(o2);
        let mut a = SsaTaintState::initial();
        let mut b = SsaTaintState::initial();
        add(&mut a, k, t1);
        add(&mut b, k, t2);
        let joined = a.join(&b);
        let cell = joined.get_field(k).unwrap();
        // `o1` appears on both sides but must be kept only once.
        assert_eq!(cell.taint.origins.len(), 2);
        let nodes: Vec<_> = cell.taint.origins.iter().map(|o| o.node).collect();
        assert!(nodes.contains(&NodeIndex::new(5)));
        assert!(nodes.contains(&NodeIndex::new(7)));
    }

    #[test]
    fn lattice_validated_must_intersects_on_join() {
        let k = key(1, 7);
        let mut a = SsaTaintState::initial();
        let mut b = SsaTaintState::initial();
        a.add_field(k, taint(Cap::ENV_VAR), true, true);
        b.add_field(k, taint(Cap::ENV_VAR), true, true);
        let joined_ab = a.join(&b);
        let cell = joined_ab.get_field(k).unwrap();
        assert!(cell.validated_must, "a.must AND b.must = true");
        assert!(cell.validated_may);
        let mut c = SsaTaintState::initial();
        c.add_field(k, taint(Cap::ENV_VAR), false, true);
        let joined_ac = a.join(&c);
        let cell2 = joined_ac.get_field(k).unwrap();
        assert!(!cell2.validated_must, "a.must AND c.must = false");
        assert!(cell2.validated_may, "a.may OR c.may = true");
    }

    #[test]
    fn lattice_validated_may_unions_on_join() {
        let k = key(1, 7);
        let mut a = SsaTaintState::initial();
        let mut b = SsaTaintState::initial();
        a.add_field(k, taint(Cap::ENV_VAR), false, false);
        b.add_field(k, taint(Cap::ENV_VAR), false, true);
        let joined = a.join(&b);
        let cell = joined.get_field(k).unwrap();
        assert!(!cell.validated_must);
        assert!(cell.validated_may, "a.may OR b.may = true");
    }

    #[test]
    fn lattice_validated_consistent_with_taint_join() {
        let k = key(2, 11);
        let mut a = SsaTaintState::initial();
        let b = SsaTaintState::initial();
        a.add_field(k, taint(Cap::ENV_VAR), true, true);
        let joined = a.join(&b);
        let cell = joined.get_field(k).unwrap();
        assert!(
            !cell.validated_must,
            "joined with empty side must drop validated_must"
        );
        assert!(
            cell.validated_may,
            "joined with empty side keeps validated_may"
        );
        assert!(cell.taint.caps.contains(Cap::ENV_VAR));
        // Same result with the operands swapped.
        let joined2 = b.join(&a);
        let cell2 = joined2.get_field(k).unwrap();
        assert!(!cell2.validated_must);
        assert!(cell2.validated_may);
    }

    #[test]
    fn lattice_leq_respects_validated_channels() {
        let k = key(3, 5);
        let mut a = SsaTaintState::initial();
        let mut b = SsaTaintState::initial();
        a.add_field(k, taint(Cap::ENV_VAR), true, false);
        b.add_field(k, taint(Cap::ENV_VAR), false, false);
        assert!(
            a.leq(&b),
            "must super-state and equal caps: a ≤ b should hold"
        );
        assert!(!b.leq(&a), "b lacks the must invariant a holds");
        let mut a2 = SsaTaintState::initial();
        let mut b2 = SsaTaintState::initial();
        a2.add_field(k, taint(Cap::ENV_VAR), false, true);
        b2.add_field(k, taint(Cap::ENV_VAR), false, false);
        assert!(!a2.leq(&b2), "a.may=true is NOT ⊆ b.may=false");
    }

    #[test]
    fn lattice_converges_under_deterministic_enumeration() {
        use petgraph::graph::NodeIndex;
        // Six single-cell states with pairwise-distinct keys and origins.
        let inputs: Vec<(FieldTaintKey, VarTaint)> = (0..6)
            .map(|i| {
                let key = FieldTaintKey {
                    loc: LocId(1 + (i % 3) as u32),
                    field: FieldId((i % 4) as u32),
                };
                let taint = VarTaint {
                    caps: if i % 2 == 0 {
                        Cap::ENV_VAR
                    } else {
                        Cap::FILE_IO
                    },
                    origins: smallvec::SmallVec::from_iter([TaintOrigin {
                        node: NodeIndex::new(i + 10),
                        source_kind: crate::labels::SourceKind::UserInput,
                        source_span: Some((i * 5, i * 5 + 2)),
                    }]),
                    uses_summary: false,
                };
                (key, taint)
            })
            .collect();
        let states: Vec<SsaTaintState> = inputs
            .iter()
            .map(|(k, t)| {
                let mut s = SsaTaintState::initial();
                add(&mut s, *k, t.clone());
                s
            })
            .collect();
        let lub = states
            .iter()
            .skip(1)
            .fold(states[0].clone(), |acc, s| acc.join(s));
        // Commutativity.
        for i in 0..states.len() {
            for j in (i + 1)..states.len() {
                let ab = states[i].join(&states[j]);
                let ba = states[j].join(&states[i]);
                assert_eq!(
                    ab, ba,
                    "join must commute: states[{i}] ⊕ states[{j}] != states[{j}] ⊕ states[{i}]",
                );
            }
        }
        // Associativity.
        for i in 0..states.len() {
            for j in 0..states.len() {
                for k in 0..states.len() {
                    let a = &states[i];
                    let b = &states[j];
                    let c = &states[k];
                    let left = a.join(b).join(c);
                    let right = a.join(&b.join(c));
                    assert_eq!(
                        left, right,
                        "join must associate: states[{i},{j},{k}] left vs right",
                    );
                }
            }
        }
        // Idempotence and absorption.
        let lub_lub = lub.join(&lub);
        assert_eq!(lub, lub_lub, "lub must be idempotent under self-join");
        for (i, s) in states.iter().enumerate() {
            let merged = lub.join(s);
            assert_eq!(
                lub, merged,
                "lub.join(states[{i}]) must equal lub (s ≤ lub)",
            );
        }
        // Repeated folding must reach a fixed point within a small bound.
        let mut acc = SsaTaintState::initial();
        let mut iter_count = 0;
        loop {
            iter_count += 1;
            if iter_count > inputs.len() + 4 {
                panic!("lattice did not converge within bounded iterations");
            }
            let mut next = acc.clone();
            for s in &states {
                next = next.join(s);
            }
            if next.field_taint == acc.field_taint {
                break;
            }
            acc = next;
        }
        assert_eq!(
            acc, lub,
            "iterative fold must converge to the lub regardless of order",
        );
    }

    #[test]
    fn lattice_leq_consistent_with_join() {
        let mut a = SsaTaintState::initial();
        let mut b = SsaTaintState::initial();
        add(&mut a, key(1, 7), taint(Cap::ENV_VAR));
        add(&mut b, key(1, 7), taint(Cap::FILE_IO));
        add(&mut b, key(2, 9), taint(Cap::SHELL_ESCAPE));
        let j = a.join(&b);
        assert!(a.leq(&j), "a ≤ a ⊕ b");
        assert!(b.leq(&j), "b ≤ a ⊕ b");
        assert!(a.leq(&a));
        assert!(b.leq(&b));
        assert!(j.leq(&j));
    }
}