use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
use crate::flow::TaintKind;
/// Summary of how taint moves through one routine, consulted when resolving calls to it.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RoutineFlowSummary {
/// Identifier that a call edge's `callee` is matched against to find this summary.
pub logical_id: String,
/// Taint kinds recorded per zero-based parameter position. During propagation a position
/// with a non-empty entry forwards the caller's actual-argument taints into the call
/// result; a missing or empty entry forwards nothing.
pub param_taints: BTreeMap<usize, Vec<TaintKind>>,
/// Taint kinds the call result carries regardless of the arguments passed.
pub returns_taint: Vec<TaintKind>,
}
/// One call site: who calls whom, and the taint carried by each actual argument.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CallEdgeFlow {
/// Identifier of the calling routine.
pub caller: String,
/// Identifier of the called routine; looked up against `RoutineFlowSummary::logical_id`.
pub callee: String,
/// Taint kinds of each actual argument, indexed by zero-based argument position.
pub actual_arg_taints: Vec<Vec<TaintKind>>,
}
/// Record of a call edge whose result taint could not be resolved.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct FlowUnknownFact {
/// Caller of the unresolved edge.
pub at_caller: String,
/// Callee of the unresolved edge.
pub callee: String,
/// Why the edge was left unresolved.
pub reason: FlowUnknownReason,
}
/// Reason a call edge was reported as unknown instead of producing a propagated return.
///
/// Serialized in snake_case (e.g. `missing_callee_summary`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FlowUnknownReason {
/// No summary was supplied whose `logical_id` equals the edge's callee.
MissingCalleeSummary,
/// The edge's caller and callee are the same routine (direct self-call).
RecursionCycle,
}
/// Output of [`propagate_inter`]: one entry per input call edge, split by outcome.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct InterFlowResult {
/// Edges that were resolved against a callee summary, in input order.
pub propagated_returns: Vec<PropagatedReturn>,
/// Edges that could not be resolved, in input order, each with its reason.
pub unknowns: Vec<FlowUnknownFact>,
}
/// Taint computed for the result of a single resolved call edge.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PropagatedReturn {
/// Caller of the resolved edge.
pub caller: String,
/// Callee of the resolved edge.
pub callee: String,
/// The callee's declared return taint, followed by any not-yet-present kinds taken from
/// actual arguments passed in positions the callee's summary marks with a non-empty entry.
pub result_taint: Vec<TaintKind>,
}
/// Resolves every call edge against the supplied routine summaries.
///
/// Each edge is handled independently (single hop): it contributes either one
/// [`PropagatedReturn`] or one [`FlowUnknownFact`] to the returned
/// [`InterFlowResult`], in the order the edges were given.
///
/// * `call_edges` - call sites to resolve.
/// * `summaries` - per-routine summaries, looked up by `logical_id`.
#[must_use]
pub fn propagate_inter(
    call_edges: &[CallEdgeFlow],
    summaries: &[RoutineFlowSummary],
) -> InterFlowResult {
    // NOTE(review): summaries sharing a logical_id collapse to a single entry here;
    // which one survives is decided by BTreeMap's FromIterator — confirm ids are unique.
    let summary_index = summaries
        .iter()
        .map(|summary| (summary.logical_id.as_str(), summary))
        .collect::<BTreeMap<&str, &RoutineFlowSummary>>();

    call_edges
        .iter()
        .fold(InterFlowResult::default(), |mut collected, edge| {
            resolve_edge(edge, &summary_index, &mut collected);
            collected
        })
}
/// Resolves a single call edge and appends the outcome to `result`.
///
/// A self-call (`caller == callee`) is reported as
/// [`FlowUnknownReason::RecursionCycle`]; that check runs before the summary lookup,
/// so it wins even when a summary for the routine exists. A callee without a summary
/// is reported as [`FlowUnknownReason::MissingCalleeSummary`]. Otherwise a
/// [`PropagatedReturn`] is pushed whose taint starts from the callee's declared
/// return taint and gains the kinds of every actual argument passed in a position
/// the summary lists with a non-empty entry.
fn resolve_edge(
    edge: &CallEdgeFlow,
    by_id: &BTreeMap<&str, &RoutineFlowSummary>,
    result: &mut InterFlowResult,
) {
    let lookup = if edge.caller == edge.callee {
        Err(FlowUnknownReason::RecursionCycle)
    } else {
        by_id
            .get(edge.callee.as_str())
            .ok_or(FlowUnknownReason::MissingCalleeSummary)
    };

    let summary = match lookup {
        Ok(found) => found,
        Err(reason) => {
            result.unknowns.push(FlowUnknownFact {
                at_caller: edge.caller.clone(),
                callee: edge.callee.clone(),
                reason,
            });
            return;
        }
    };

    // Kinds of the actual arguments, in argument order, restricted to positions the
    // callee's summary marks with at least one kind.
    let forwarded_kinds = edge
        .actual_arg_taints
        .iter()
        .enumerate()
        .filter(|(position, _)| {
            summary
                .param_taints
                .get(position)
                .is_some_and(|kinds| !kinds.is_empty())
        })
        .flat_map(|(_, actual)| actual.iter().copied());

    // Start from the declared return taint; only the forwarded kinds are deduplicated.
    let mut merged = summary.returns_taint.clone();
    for kind in forwarded_kinds {
        if !merged.contains(&kind) {
            merged.push(kind);
        }
    }

    result.propagated_returns.push(PropagatedReturn {
        caller: edge.caller.clone(),
        callee: edge.callee.clone(),
        result_taint: merged,
    });
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a summary from `(parameter position, kinds)` pairs and a declared return taint.
    fn summ(id: &str, params: &[(usize, &[TaintKind])], ret: &[TaintKind]) -> RoutineFlowSummary {
        let mut per_position = BTreeMap::new();
        for &(position, kinds) in params {
            per_position.insert(position, kinds.to_vec());
        }
        RoutineFlowSummary {
            logical_id: id.into(),
            param_taints: per_position,
            returns_taint: ret.to_vec(),
        }
    }

    /// Builds a call edge from `caller` to `callee` with the given per-argument taints.
    fn call(caller: &str, callee: &str, actual_arg_taints: Vec<Vec<TaintKind>>) -> CallEdgeFlow {
        CallEdgeFlow {
            caller: caller.into(),
            callee: callee.into(),
            actual_arg_taints,
        }
    }

    #[test]
    fn taint_flows_through_propagating_param_to_result() {
        let edges = [call("a", "b", vec![vec![TaintKind::UserInput]])];
        let summaries = [summ("b", &[(0, &[TaintKind::UserInput])], &[])];

        let outcome = propagate_inter(&edges, &summaries);

        assert_eq!(outcome.propagated_returns.len(), 1);
        let taint = &outcome.propagated_returns[0].result_taint;
        assert!(taint.contains(&TaintKind::UserInput));
        assert!(outcome.unknowns.is_empty());
    }

    #[test]
    fn non_propagating_param_does_not_taint_result() {
        let edges = [call("a", "b", vec![vec![TaintKind::UserInput]])];
        let summaries = [summ("b", &[], &[])];

        let outcome = propagate_inter(&edges, &summaries);

        assert!(outcome.propagated_returns[0].result_taint.is_empty());
    }

    #[test]
    fn declared_return_taint_always_present() {
        let edges = [call("a", "b", vec![])];
        let summaries = [summ("b", &[], &[TaintKind::DbLink])];

        let outcome = propagate_inter(&edges, &summaries);

        let taint = &outcome.propagated_returns[0].result_taint;
        assert!(taint.contains(&TaintKind::DbLink));
    }

    #[test]
    fn missing_summary_records_unknown() {
        let edges = [call("a", "external_pkg.proc", vec![])];

        let outcome = propagate_inter(&edges, &[]);

        assert_eq!(outcome.unknowns.len(), 1);
        assert_eq!(
            outcome.unknowns[0].reason,
            FlowUnknownReason::MissingCalleeSummary
        );
    }

    #[test]
    fn direct_recursion_records_cycle_unknown() {
        let edges = [call("rec", "rec", vec![])];
        let summaries = [summ("rec", &[], &[])];

        let outcome = propagate_inter(&edges, &summaries);

        assert_eq!(
            outcome.unknowns[0].reason,
            FlowUnknownReason::RecursionCycle
        );
    }

    #[test]
    fn multiple_taint_kinds_union_into_result() {
        let edges = [call(
            "a",
            "b",
            vec![vec![TaintKind::UserInput, TaintKind::BindVariable]],
        )];
        let summaries = [summ("b", &[(0, &[TaintKind::UserInput])], &[])];

        let outcome = propagate_inter(&edges, &summaries);

        let taint = &outcome.propagated_returns[0].result_taint;
        assert!(taint.contains(&TaintKind::UserInput));
        assert!(taint.contains(&TaintKind::BindVariable));
    }

    #[test]
    fn result_taint_dedupes() {
        let edges = [call("a", "b", vec![vec![TaintKind::UserInput]])];
        let summaries = [summ(
            "b",
            &[(0, &[TaintKind::UserInput])],
            &[TaintKind::UserInput],
        )];

        let outcome = propagate_inter(&edges, &summaries);

        let occurrences = outcome.propagated_returns[0]
            .result_taint
            .iter()
            .filter(|kind| **kind == TaintKind::UserInput)
            .count();
        assert_eq!(occurrences, 1);
    }

    #[test]
    fn serde_round_trip() {
        let edges = [call("a", "missing", vec![])];
        let outcome = propagate_inter(&edges, &[]);

        let json = serde_json::to_string(&outcome).unwrap();
        let decoded: InterFlowResult = serde_json::from_str(&json).unwrap();

        assert_eq!(decoded, outcome);
        assert!(json.contains("missing_callee_summary"));
    }

    #[test]
    fn chain_is_resolved_single_hop_not_transitively() {
        let edges = [
            call("a", "b", vec![vec![TaintKind::UserInput]]),
            call("b", "c", vec![vec![TaintKind::UserInput]]),
        ];
        let summaries = [
            summ("b", &[(0, &[TaintKind::UserInput])], &[]),
            summ("c", &[], &[TaintKind::DbLink]),
        ];

        let outcome = propagate_inter(&edges, &summaries);

        assert!(outcome.unknowns.is_empty());
        assert_eq!(outcome.propagated_returns.len(), 2);

        let from_a = outcome
            .propagated_returns
            .iter()
            .find(|record| record.caller == "a")
            .expect("a→b record present");
        assert!(
            from_a.result_taint.contains(&TaintKind::UserInput),
            "a→b folds the actual's UserInput through b's propagating param"
        );
        assert!(
            !from_a.result_taint.contains(&TaintKind::DbLink),
            "single-hop: c's DbLink must NOT transitively reach a"
        );

        let from_b = outcome
            .propagated_returns
            .iter()
            .find(|record| record.caller == "b")
            .expect("b→c record present");
        assert!(
            from_b.result_taint.contains(&TaintKind::DbLink),
            "b→c carries c's declared return taint to b"
        );
    }

    #[test]
    fn distinct_caller_callee_with_same_name_in_two_edges_is_not_a_cycle() {
        let edges = [call("a", "b", vec![]), call("b", "a", vec![])];
        let summaries = [summ("a", &[], &[]), summ("b", &[], &[])];

        let outcome = propagate_inter(&edges, &summaries);

        assert!(
            outcome.unknowns.is_empty(),
            "mutual edges are single-hop resolvable, not direct self-recursion"
        );
        assert_eq!(outcome.propagated_returns.len(), 2);
    }
}