use std::collections::HashSet;
use crate::error::{Diagnostics, Error};
use crate::ir::{Node, NodeId, SymbolicFlowGraph};
use crate::phase::{Phase, PhaseNodeKind, Transition, transition};
pub fn validate(graph: &SymbolicFlowGraph) -> Result<(), Error> {
validate_collecting(graph).into_result(()).map_err(Error::from)
}
/// Validates `graph` and returns every diagnostic that was recorded.
///
/// The id-range check runs first and acts as a gate: the fetch-edge, cycle and
/// phase checks index into the graph's tables, so they only run when no
/// out-of-range id was reported.
#[must_use]
pub fn validate_collecting(graph: &SymbolicFlowGraph) -> Diagnostics {
    let mut diagnostics = Diagnostics::new();

    check_id_ranges(graph, &mut diagnostics);
    if !diagnostics.is_empty() {
        // Later passes index by id; running them on dangling ids is not safe.
        return diagnostics;
    }

    check_fetch_edges(graph, &mut diagnostics);
    check_acyclic(graph, &mut diagnostics);
    check_phases_collecting(graph, &mut diagnostics);
    diagnostics
}
fn check_id_ranges(graph: &SymbolicFlowGraph, d: &mut Diagnostics) {
let n_nodes = u32::try_from(graph.nodes.len()).unwrap_or(u32::MAX);
let n_preds = u32::try_from(graph.predicates.len()).unwrap_or(u32::MAX);
let n_mws = u32::try_from(graph.middlewares.len()).unwrap_or(u32::MAX);
let n_fetches = u32::try_from(graph.fetches.len()).unwrap_or(u32::MAX);
let n_terms = u32::try_from(graph.terminators.len()).unwrap_or(u32::MAX);
for (idx, node) in graph.nodes.iter().enumerate() {
match node {
Node::Check { predicate, on_match, on_miss, .. } => {
if predicate.get() >= n_preds {
d.push(Error::compile(format!("node {idx}: dangling PredicateId({})", predicate.get())));
}
if on_match.get() >= n_nodes {
d.push(Error::compile(format!("node {idx}.on_match dangling")));
}
if on_miss.get() >= n_nodes {
d.push(Error::compile(format!("node {idx}.on_miss dangling")));
}
}
Node::Middleware { id, next, on_error, .. } => {
if id.get() >= n_mws {
d.push(Error::compile(format!("node {idx}: dangling MiddlewareId({})", id.get())));
}
if next.get() >= n_nodes {
d.push(Error::compile(format!("node {idx}.next dangling")));
}
if let Some(e) = on_error
&& e.get() >= n_nodes
{
d.push(Error::compile(format!("node {idx}.on_error dangling")));
}
}
Node::Fetch { id, next_response, next_tunnel, .. } => {
if id.get() >= n_fetches {
d.push(Error::compile(format!("node {idx}: dangling FetchId({})", id.get())));
}
if let Some(r) = next_response
&& r.get() >= n_nodes
{
d.push(Error::compile(format!("node {idx}.next_response dangling")));
}
if let Some(t) = next_tunnel
&& t.get() >= n_nodes
{
d.push(Error::compile(format!("node {idx}.next_tunnel dangling")));
}
}
Node::Upgrade { next } => {
if next.get() >= n_nodes {
d.push(Error::compile(format!("node {idx}.next dangling")));
}
}
Node::Terminate(t) => {
if t.get() >= n_terms {
d.push(Error::compile(format!("node {idx}: dangling TerminatorId({})", t.get())));
}
}
}
}
}
/// Checks that every `Node::Fetch` carries exactly the outgoing edges its
/// fetch kind allows.
///
/// * `HttpProxy`, `HttpSynthesize`, `AcmeChallenge`: `next_response` required,
///   `next_tunnel` forbidden.
/// * `L4Forward`: `next_tunnel` required, `next_response` forbidden.
/// * `WebSocketUpgrade`: both edges required.
///
/// Each violated rule pushes its own diagnostic. The fetch id is used to index
/// `graph` directly, so ids must already have been range-checked.
fn check_fetch_edges(graph: &SymbolicFlowGraph, d: &mut Diagnostics) {
    use crate::fetch::FetchKind::{
        AcmeChallenge, HttpProxy, HttpSynthesize, L4Forward, WebSocketUpgrade,
    };

    for (idx, node) in graph.nodes.iter().enumerate() {
        // Only fetch nodes have kind-dependent edge requirements.
        if let Node::Fetch { id, next_response, next_tunnel, .. } = node {
            let kind = graph[*id].kind;
            let has_response = next_response.is_some();
            let has_tunnel = next_tunnel.is_some();

            match kind {
                HttpProxy | HttpSynthesize | AcmeChallenge => {
                    if !has_response {
                        d.push(Error::compile(format!(
                            "node {idx}: {kind:?} requires next_response"
                        )));
                    }
                    if has_tunnel {
                        d.push(Error::compile(format!(
                            "node {idx}: {kind:?} must not have next_tunnel"
                        )));
                    }
                }
                L4Forward => {
                    if !has_tunnel {
                        d.push(Error::compile(format!(
                            "node {idx}: L4Forward requires next_tunnel"
                        )));
                    }
                    if has_response {
                        d.push(Error::compile(format!(
                            "node {idx}: L4Forward must not have next_response"
                        )));
                    }
                }
                WebSocketUpgrade => {
                    if !(has_response && has_tunnel) {
                        d.push(Error::compile(format!(
                            "node {idx}: WebSocketUpgrade requires both next_response and next_tunnel"
                        )));
                    }
                }
            }
        }
    }
}
fn check_acyclic(graph: &SymbolicFlowGraph, d: &mut Diagnostics) {
#[derive(Copy, Clone)]
enum Color {
White,
Gray,
Black,
}
let mut color: Vec<Color> = (0..graph.nodes.len()).map(|_| Color::White).collect();
let mut reported: HashSet<usize> = HashSet::new();
for start in 0..graph.nodes.len() {
if !matches!(color[start], Color::White) {
continue;
}
let mut stack: Vec<(usize, usize)> = vec![(start, 0)];
color[start] = Color::Gray;
while let Some(&(node_idx, child_idx)) = stack.last() {
let succs = successors(&graph.nodes[node_idx]);
if child_idx < succs.len() {
let next = succs[child_idx].get() as usize;
stack.last_mut().expect("non-empty").1 += 1;
match color[next] {
Color::White => {
color[next] = Color::Gray;
stack.push((next, 0));
}
Color::Gray => {
if reported.insert(next) {
d.push(Error::compile(format!("cycle in graph at node {next}")));
}
}
Color::Black => {}
}
} else {
color[node_idx] = Color::Black;
stack.pop();
}
}
}
}
/// Lists the outgoing edges of `node` in a fixed order.
///
/// * `Check`: `on_match`, then `on_miss`.
/// * `Middleware`: `next`, then `on_error` when present.
/// * `Fetch`: `next_response` when present, then `next_tunnel` when present.
/// * `Upgrade`: `next`.
/// * `Terminate`: no successors.
fn successors(node: &Node) -> Vec<NodeId> {
    match node {
        Node::Terminate(_) => Vec::new(),
        Node::Upgrade { next } => vec![*next],
        Node::Check { on_match, on_miss, .. } => vec![*on_match, *on_miss],
        Node::Middleware { next, on_error, .. } => {
            std::iter::once(*next).chain(on_error.iter().copied()).collect()
        }
        Node::Fetch { next_response, next_tunnel, .. } => {
            next_response.iter().chain(next_tunnel.iter()).copied().collect()
        }
    }
}
/// Maps `node` to the `PhaseNodeKind` consumed by the phase transition table.
///
/// Middleware and fetch nodes carry the `kind` of the table entry their id
/// refers to; terminate nodes carry the referenced terminator itself. The ids
/// index `graph` directly, so they must already have been range-checked.
fn node_kind_for_phase(graph: &SymbolicFlowGraph, node: &Node) -> PhaseNodeKind {
    match *node {
        Node::Check { .. } => PhaseNodeKind::Check,
        Node::Upgrade { .. } => PhaseNodeKind::Upgrade,
        Node::Middleware { id, .. } => {
            let middleware = &graph[id];
            PhaseNodeKind::Middleware(middleware.kind)
        }
        Node::Fetch { id, .. } => {
            let fetch = &graph[id];
            PhaseNodeKind::Fetch(fetch.kind)
        }
        Node::Terminate(terminator_id) => PhaseNodeKind::Terminate(graph[terminator_id]),
    }
}
pub fn check_phases(graph: &SymbolicFlowGraph) -> Result<(), Error> {
let mut d = Diagnostics::new();
check_phases_collecting(graph, &mut d);
d.into_result(()).map_err(Error::from)
}
/// Walks the graph from every root and records phase violations in `d`.
///
/// Roots are the values of `graph.entries`, walked starting in
/// `Phase::L4Raw`, followed by the values of
/// `graph.meta.short_circuit_response_entry`, walked starting in
/// `Phase::L7Response`. All walks share one visited set of `(node, phase)`
/// pairs, so a pair reached from several roots is examined once. Each root
/// contributes at most one diagnostic because `visit_phase` stops at its first
/// error.
fn check_phases_collecting(graph: &SymbolicFlowGraph, d: &mut Diagnostics) {
    let mut visited: HashSet<(NodeId, Phase)> = HashSet::new();

    let entry_roots = graph.entries.values().map(|&root| (root, Phase::L4Raw));
    let synth_roots = graph
        .meta
        .short_circuit_response_entry
        .values()
        .map(|&root| (root, Phase::L7Response));

    for (root, start_phase) in entry_roots.chain(synth_roots) {
        if let Err(problem) = visit_phase(graph, root, start_phase, &mut visited) {
            d.push(problem);
        }
    }
}
fn visit_phase(
graph: &SymbolicFlowGraph,
id: NodeId,
phase: Phase,
seen: &mut HashSet<(NodeId, Phase)>,
) -> Result<(), Error> {
if !seen.insert((id, phase)) {
return Ok(());
}
let node = &graph[id];
let kind = node_kind_for_phase(graph, node);
let t = transition(kind, phase).map_err(|e| {
Error::compile(format!(
"phase mismatch at NodeId({}): expected one of {:?}, got {:?}",
id.get(),
e.expected,
e.got,
))
})?;
match (t, node) {
(Transition::Terminal, _) => Ok(()),
(Transition::PassThrough, _) => {
for succ in successors(node) {
visit_phase(graph, succ, phase, seen)?;
}
Ok(())
}
(Transition::Into(next_phase), _) => {
for succ in successors(node) {
visit_phase(graph, succ, next_phase, seen)?;
}
Ok(())
}
(
Transition::BiOutcome { response, tunnel },
Node::Fetch { next_response, next_tunnel, .. },
) => {
if let Some(r) = next_response {
visit_phase(graph, *r, response, seen)?;
}
if let Some(t) = next_tunnel {
visit_phase(graph, *t, tunnel, seen)?;
}
Ok(())
}
(Transition::BiOutcome { .. }, _) => {
Err(Error::compile("BiOutcome transition on non-Fetch node".to_string()))
}
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for the graph validator.
    //!
    //! Each test builds the smallest graph that triggers one diagnostic. Shared
    //! builders keep the per-test setup down to the fields that matter: start
    //! from [`graph_of`] and override individual tables with struct-update
    //! syntax.

    use std::collections::HashMap;
    use std::path::PathBuf;
    use std::time::SystemTime;

    use super::*;
    use crate::fetch::{FetchKind, SymbolicFetchRef, Terminator};
    use crate::ir::{BodySide, FetchId, FlowGraphMeta, MiddlewareId, PredicateId, TerminatorId};
    use crate::middleware::{MiddlewareKind, SymbolicMiddlewareRef};

    // ---------------------------------------------------------------------
    // Builders
    // ---------------------------------------------------------------------

    /// Metadata with empty maps and placeholder provenance fields.
    fn empty_meta() -> FlowGraphMeta {
        FlowGraphMeta {
            version_hash: [0; 32],
            compiled_at: SystemTime::UNIX_EPOCH,
            source_files: vec![PathBuf::new()],
            feature_set: &[],
            short_circuit_response_entry: std::collections::BTreeMap::new(),
            listener_tls: std::collections::BTreeMap::new(),
            listener_kinds: std::collections::BTreeMap::new(),
            listener_transports: std::collections::BTreeMap::new(),
            annotations: Vec::new(),
        }
    }

    /// A graph holding only `nodes`; every side table and the entry map are empty.
    fn graph_of(nodes: Vec<Node>) -> SymbolicFlowGraph {
        SymbolicFlowGraph {
            nodes,
            predicates: Vec::new(),
            middlewares: Vec::new(),
            fetches: Vec::new(),
            terminators: Vec::new(),
            entries: HashMap::new(),
            meta: empty_meta(),
        }
    }

    /// A `Node::Check` with no body collection.
    fn check(predicate: PredicateId, on_match: NodeId, on_miss: NodeId) -> Node {
        Node::Check {
            predicate,
            on_match,
            on_miss,
            collect_body_before: None,
            body_limit: 0,
        }
    }

    /// A `Node::Middleware` with no body collection.
    fn middleware(id: MiddlewareId, next: NodeId, on_error: Option<NodeId>) -> Node {
        Node::Middleware {
            id,
            next,
            on_error,
            collect_body_before: None,
            body_limit: 0,
        }
    }

    /// A `Node::Fetch` with no body collection.
    fn fetch(id: FetchId, next_response: Option<NodeId>, next_tunnel: Option<NodeId>) -> Node {
        Node::Fetch {
            id,
            next_response,
            next_tunnel,
            collect_body_before: None,
            body_limit: 0,
        }
    }

    /// A fetch table entry of the given `kind` with null args and default flags.
    fn fetch_ref(kind: FetchKind) -> SymbolicFetchRef {
        SymbolicFetchRef {
            kind,
            args: serde_json::Value::Null,
            retry_buffer_required: false,
            allow_zero_rtt: None,
        }
    }

    fn dummy_predicate() -> crate::predicate::PredicateInst {
        use crate::predicate::{CompiledOperator, CompiledValue, FieldPath, PredicateInst};
        PredicateInst {
            path: FieldPath::TlsSni,
            op: CompiledOperator::Equals(CompiledValue::Str(std::sync::Arc::from("x"))),
        }
    }

    fn dummy_middleware_ref() -> SymbolicMiddlewareRef {
        SymbolicMiddlewareRef {
            name: std::sync::Arc::from("noop"),
            args: serde_json::Value::Null,
            kind: MiddlewareKind::L4Peek,
            stateless: true,
            needs_body: false,
            on_error: None,
        }
    }

    /// `Terminate(ByteTunnel)` at node 0, an `Upgrade` to it at node 1, and a
    /// short-circuit response entry whose target is node 0.
    fn short_circuit_into_byte_tunnel() -> SymbolicFlowGraph {
        let mut meta = empty_meta();
        meta.short_circuit_response_entry.insert(NodeId::new(1), NodeId::new(0));
        SymbolicFlowGraph {
            terminators: vec![Terminator::ByteTunnel],
            meta,
            ..graph_of(vec![
                Node::Terminate(TerminatorId::new(0)),
                Node::Upgrade { next: NodeId::new(0) },
            ])
        }
    }

    /// Asserts that `validate` fails and that the error text contains `needle`.
    fn assert_err_contains(graph: &SymbolicFlowGraph, needle: &str) {
        let err = validate(graph).expect_err("must error");
        let msg = err.to_string();
        assert!(msg.contains(needle), "expected {needle:?} in error, got: {msg}");
    }

    // ---------------------------------------------------------------------
    // Id-range checks
    // ---------------------------------------------------------------------

    #[test]
    fn validate_collecting_accumulates_every_dangling_check_edge_in_one_pass() {
        let graph = SymbolicFlowGraph {
            predicates: vec![dummy_predicate()],
            ..graph_of(vec![
                check(PredicateId::new(0), NodeId::new(50), NodeId::new(51)),
                check(PredicateId::new(0), NodeId::new(52), NodeId::new(53)),
            ])
        };
        let d = validate_collecting(&graph);
        assert_eq!(d.len(), 4, "expected one error per dangling edge: {d}");
        let dump = d.to_string();
        assert!(dump.contains("on_match dangling"), "{dump}");
        assert!(dump.contains("on_miss dangling"), "{dump}");
    }

    #[test]
    fn dangling_terminator_id_in_terminate_node_rejected() {
        let graph = graph_of(vec![Node::Terminate(TerminatorId::new(0))]);
        assert_err_contains(&graph, "dangling TerminatorId");
    }

    #[test]
    fn dangling_node_id_in_fetch_edge_rejected() {
        let graph = SymbolicFlowGraph {
            fetches: vec![fetch_ref(FetchKind::HttpProxy)],
            ..graph_of(vec![fetch(FetchId::new(0), Some(NodeId::new(99)), None)])
        };
        assert_err_contains(&graph, "next_response dangling");
    }

    #[test]
    fn validate_rejects_dangling_predicate_id_in_check() {
        let graph = graph_of(vec![check(PredicateId::new(7), NodeId::new(0), NodeId::new(0))]);
        assert_err_contains(&graph, "dangling PredicateId");
    }

    #[test]
    fn validate_rejects_dangling_on_match_in_check() {
        let graph = SymbolicFlowGraph {
            predicates: vec![dummy_predicate()],
            ..graph_of(vec![check(PredicateId::new(0), NodeId::new(42), NodeId::new(0))])
        };
        assert_err_contains(&graph, "on_match dangling");
    }

    #[test]
    fn validate_rejects_dangling_on_miss_in_check() {
        let graph = SymbolicFlowGraph {
            predicates: vec![dummy_predicate()],
            ..graph_of(vec![check(PredicateId::new(0), NodeId::new(0), NodeId::new(42))])
        };
        assert_err_contains(&graph, "on_miss dangling");
    }

    #[test]
    fn validate_rejects_dangling_middleware_id() {
        let graph = graph_of(vec![middleware(MiddlewareId::new(7), NodeId::new(0), None)]);
        assert_err_contains(&graph, "dangling MiddlewareId");
    }

    #[test]
    fn validate_rejects_dangling_next_in_middleware() {
        let graph = SymbolicFlowGraph {
            middlewares: vec![dummy_middleware_ref()],
            ..graph_of(vec![middleware(MiddlewareId::new(0), NodeId::new(42), None)])
        };
        assert_err_contains(&graph, "next dangling");
    }

    #[test]
    fn validate_rejects_dangling_on_error_in_middleware() {
        let graph = SymbolicFlowGraph {
            middlewares: vec![dummy_middleware_ref()],
            ..graph_of(vec![middleware(
                MiddlewareId::new(0),
                NodeId::new(0),
                Some(NodeId::new(42)),
            )])
        };
        assert_err_contains(&graph, "on_error dangling");
    }

    #[test]
    fn validate_rejects_dangling_fetch_id() {
        let graph = graph_of(vec![fetch(FetchId::new(7), Some(NodeId::new(0)), None)]);
        assert_err_contains(&graph, "dangling FetchId");
    }

    #[test]
    fn validate_rejects_dangling_next_tunnel() {
        let graph = SymbolicFlowGraph {
            fetches: vec![fetch_ref(FetchKind::L4Forward)],
            ..graph_of(vec![fetch(FetchId::new(0), None, Some(NodeId::new(42)))])
        };
        assert_err_contains(&graph, "next_tunnel dangling");
    }

    #[test]
    fn validate_rejects_dangling_next_in_upgrade() {
        let graph = graph_of(vec![Node::Upgrade { next: NodeId::new(42) }]);
        assert_err_contains(&graph, "next dangling");
    }

    // ---------------------------------------------------------------------
    // Fetch edge rules
    // ---------------------------------------------------------------------

    #[test]
    fn http_fetch_without_next_response_rejected() {
        let graph = SymbolicFlowGraph {
            fetches: vec![fetch_ref(FetchKind::HttpProxy)],
            terminators: vec![Terminator::WriteHttpResponse],
            ..graph_of(vec![
                Node::Terminate(TerminatorId::new(0)),
                fetch(FetchId::new(0), None, None),
            ])
        };
        assert_err_contains(&graph, "requires next_response");
    }

    #[test]
    fn l4_forward_with_next_response_rejected() {
        let graph = SymbolicFlowGraph {
            fetches: vec![fetch_ref(FetchKind::L4Forward)],
            terminators: vec![Terminator::ByteTunnel],
            ..graph_of(vec![
                Node::Terminate(TerminatorId::new(0)),
                fetch(FetchId::new(0), Some(NodeId::new(0)), Some(NodeId::new(0))),
            ])
        };
        assert_err_contains(&graph, "L4Forward must not have next_response");
    }

    #[test]
    fn validate_rejects_http_fetch_with_next_tunnel() {
        let graph = SymbolicFlowGraph {
            fetches: vec![fetch_ref(FetchKind::HttpProxy)],
            terminators: vec![Terminator::WriteHttpResponse],
            ..graph_of(vec![
                Node::Terminate(TerminatorId::new(0)),
                fetch(FetchId::new(0), Some(NodeId::new(0)), Some(NodeId::new(0))),
            ])
        };
        assert_err_contains(&graph, "must not have next_tunnel");
    }

    #[test]
    fn validate_rejects_l4_forward_without_next_tunnel() {
        let graph = SymbolicFlowGraph {
            fetches: vec![fetch_ref(FetchKind::L4Forward)],
            ..graph_of(vec![fetch(FetchId::new(0), None, None)])
        };
        assert_err_contains(&graph, "L4Forward requires next_tunnel");
    }

    #[test]
    fn validate_rejects_websocket_upgrade_missing_branch() {
        let graph = SymbolicFlowGraph {
            fetches: vec![fetch_ref(FetchKind::WebSocketUpgrade)],
            terminators: vec![Terminator::WriteHttpResponse],
            ..graph_of(vec![
                Node::Terminate(TerminatorId::new(0)),
                fetch(FetchId::new(0), None, Some(NodeId::new(0))),
            ])
        };
        assert_err_contains(&graph, "WebSocketUpgrade requires both");
    }

    // ---------------------------------------------------------------------
    // Cycle detection
    // ---------------------------------------------------------------------

    #[test]
    fn cyclic_graph_is_rejected() {
        // Node 0 and node 1 point at each other on both branches.
        let graph = SymbolicFlowGraph {
            predicates: vec![dummy_predicate()],
            ..graph_of(vec![
                check(PredicateId::new(0), NodeId::new(1), NodeId::new(1)),
                check(PredicateId::new(0), NodeId::new(0), NodeId::new(0)),
            ])
        };
        assert_err_contains(&graph, "cycle");
    }

    // ---------------------------------------------------------------------
    // Phase checks
    // ---------------------------------------------------------------------

    #[test]
    fn phase_check_rejects_write_http_response_reached_in_wrong_phase() {
        let mut graph = SymbolicFlowGraph {
            terminators: vec![Terminator::WriteHttpResponse],
            ..graph_of(vec![
                Node::Terminate(TerminatorId::new(0)),
                Node::Upgrade { next: NodeId::new(0) },
            ])
        };
        graph.entries.insert("127.0.0.1:443".parse().expect("parse"), NodeId::new(1));
        let err = check_phases(&graph).expect_err("must error");
        assert!(err.to_string().contains("phase mismatch"));
    }

    #[test]
    fn phase_check_rejects_short_circuit_synth_with_wrong_terminator() {
        let graph = short_circuit_into_byte_tunnel();
        let err = check_phases(&graph).expect_err("must error on bad synth phase");
        assert!(err.to_string().contains("phase mismatch"), "{err}");
    }

    #[test]
    fn validate_rejects_bi_outcome_transition_on_non_fetch_node() {
        // NOTE(review): this graph is identical to the one used by
        // `phase_check_rejects_short_circuit_synth_with_wrong_terminator`, whose
        // failure is a "phase mismatch". Nothing here shows the
        // "BiOutcome transition on non-Fetch node" branch being reached, so this
        // only asserts that the phase check fails. Confirm whether that branch
        // can be exercised through the public transition table at all.
        let graph = short_circuit_into_byte_tunnel();
        assert!(check_phases(&graph).is_err());
    }

    // Compile-time reference that keeps the `BodySide` import in use.
    const _: BodySide = BodySide::Request;
}