mod test_utils;
use polaris_graph::executor::{ExecutionError, GraphExecutor};
use polaris_graph::graph::Graph;
use polaris_graph::node::{ContextPolicy, NodeId};
use test_utils::{
DecisionOutput, DecisionSystem, ExecutionLog, FailingSystem, SwitchKeySystem, SwitchOutput,
TrackerNodes, add_tracker, branch, create_test_server, get_hooks,
};
/// A composable description of a graph shape used to drive executor tests.
///
/// Each variant corresponds to one graph-construction step in `Fragment::build`; `Track`
/// leaves become tracker nodes whose execution counts the tests assert on.
#[derive(Clone, Debug)]
enum Fragment {
    /// A single tracker node (added with `add_tracker`).
    Track,
    /// Sub-fragments built one after another into the same graph.
    Seq(Vec<Fragment>),
    /// Sub-fragments built as the branches of one `add_parallel` node.
    Par(Vec<Fragment>),
    /// A `DecisionSystem` followed by a conditional branch; `take_true` is handed to the
    /// system and the branch selects `t` when the output's `take_true` is set, else `f`.
    Decision {
        take_true: bool,
        t: Box<Fragment>,
        f: Box<Fragment>,
    },
    /// `body` built inside an `add_loop_n` node with `n` iterations.
    Loop { n: usize, body: Box<Fragment> },
    /// A `SwitchKeySystem` keyed by `case` followed by a switch with arms "a" and "b"
    /// and no default arm.
    Switch {
        case: &'static str,
        a: Box<Fragment>,
        b: Box<Fragment>,
    },
    /// A `FailingSystem` node with `handler` attached through `on_error`.
    Fallible { handler: Box<Fragment> },
    /// `body` built into a separate `Graph` and embedded via `add_scope` with
    /// `ContextPolicy::shared()`.
    Scope { body: Box<Fragment> },
}
impl Fragment {
    /// Appends the graph structure described by this fragment to `g`.
    ///
    /// Each `Track` leaf adds a tracker node with `add_tracker` and records its id in
    /// `trackers`. Sub-fragments that are built inside closures (parallel branches,
    /// decision/switch arms, loop bodies, error handlers) are cloned together with the
    /// tracker handle, because those closures are `move` closures that own their captures.
    ///
    /// `predicted_counts` assumes ids are recorded in a depth-first, left-to-right walk,
    /// i.e. that the graph builder invokes branch closures in the order they are passed.
    /// NOTE(review): confirm against the builder API.
    fn build(&self, g: &mut Graph, trackers: &TrackerNodes) {
        match self {
            Fragment::Track => {
                trackers.add(add_tracker(g));
            }
            Fragment::Seq(items) => {
                for item in items {
                    item.build(g, trackers);
                }
            }
            Fragment::Par(branches) => {
                // Each branch closure owns its own copy of the sub-fragment and tracker handle.
                let branch_fns: Vec<_> = branches
                    .iter()
                    .map(|b| {
                        let b = b.clone();
                        let trackers = trackers.clone();
                        branch(move |g| b.build(g, &trackers))
                    })
                    .collect();
                g.add_parallel("par", branch_fns);
            }
            Fragment::Decision { take_true, t, f } => {
                let t = t.clone();
                let f = f.clone();
                let t_trackers = trackers.clone();
                let f_trackers = trackers.clone();
                // Presumably produces the `DecisionOutput` that the conditional branch reads.
                g.add_boxed_system(Box::new(DecisionSystem {
                    take_true: *take_true,
                }));
                g.add_conditional_branch::<DecisionOutput, _, _, _>(
                    "decision",
                    |output| output.take_true,
                    move |g| t.build(g, &t_trackers),
                    move |g| f.build(g, &f_trackers),
                );
            }
            Fragment::Loop { n, body } => {
                let body = body.clone();
                let trackers = trackers.clone();
                g.add_loop_n("loop", *n, move |g| body.build(g, &trackers));
            }
            Fragment::Switch { case, a, b } => {
                let a = a.clone();
                let b = b.clone();
                let a_trackers = trackers.clone();
                let b_trackers = trackers.clone();
                // Presumably produces the `SwitchOutput` whose key selects the arm below.
                g.add_boxed_system(Box::new(SwitchKeySystem { key: case }));
                g.add_switch::<SwitchOutput, _, _, _>(
                    "switch",
                    |output| output.key,
                    [
                        ("a", branch(move |g| a.build(g, &a_trackers))),
                        ("b", branch(move |g| b.build(g, &b_trackers))),
                    ],
                    None,
                );
            }
            Fragment::Fallible { handler } => {
                let handler = handler.clone();
                let trackers = trackers.clone();
                g.system_boxed(Box::new(FailingSystem))
                    .on_error(move |g| handler.build(g, &trackers))
                    .done();
            }
            Fragment::Scope { body } => {
                // The scope body lives in its own graph, built eagerly before embedding.
                let mut inner = Graph::new();
                body.build(&mut inner, trackers);
                g.add_scope("scope", inner, ContextPolicy::shared());
            }
        }
    }

    /// Whether, in the prediction model, this fragment ends the enclosing `Seq`.
    ///
    /// Only a `Fallible` (directly, or anywhere inside a nested `Seq`) diverts; nodes that
    /// follow it in a sequence are predicted not to run. All other constructs are modelled
    /// as containing a diversion that happens inside them.
    fn is_diverting(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Fragment::Fallible { .. } => true,
            Fragment::Seq(items) => items.iter().any(Fragment::is_diverting),
            Fragment::Track
            | Fragment::Par(_)
            | Fragment::Decision { .. }
            | Fragment::Loop { .. }
            | Fragment::Switch { .. }
            | Fragment::Scope { .. } => false,
        }
    }

    /// Number of `Track` leaves, i.e. how many tracker ids `build` records.
    fn tracker_count(&self) -> usize {
        match self {
            Fragment::Track => 1,
            Fragment::Seq(items) => items.iter().map(Fragment::tracker_count).sum(),
            Fragment::Par(branches) => branches.iter().map(Fragment::tracker_count).sum(),
            Fragment::Decision { t, f, .. } => t.tracker_count() + f.tracker_count(),
            Fragment::Loop { body, .. } | Fragment::Scope { body } => body.tracker_count(),
            Fragment::Switch { a, b, .. } => a.tracker_count() + b.tracker_count(),
            Fragment::Fallible { handler } => handler.tracker_count(),
        }
    }

    /// A zero count for every tracker in this fragment (used for branches that never run).
    fn zero_counts(&self) -> Vec<usize> {
        vec![0; self.tracker_count()]
    }

    /// Predicted execution count of each tracker, in the order `build` records them.
    fn predicted_counts(&self) -> Vec<usize> {
        self.predicted_counts_inner(1)
    }

    /// Recursive worker for `predicted_counts`.
    ///
    /// `multiplier` is how many times this fragment itself runs: the product of the
    /// iteration counts of all enclosing loops.
    fn predicted_counts_inner(&self, multiplier: usize) -> Vec<usize> {
        match self {
            Fragment::Track => vec![multiplier],
            Fragment::Seq(items) => {
                let mut counts = Vec::new();
                let mut diverted = false;
                for item in items {
                    if diverted {
                        // Everything after a diverting item is unreachable.
                        counts.extend(item.zero_counts());
                    } else {
                        counts.extend(item.predicted_counts_inner(multiplier));
                        diverted = item.is_diverting();
                    }
                }
                counts
            }
            Fragment::Par(branches) => branches
                .iter()
                .flat_map(|b| b.predicted_counts_inner(multiplier))
                .collect(),
            Fragment::Decision { take_true, t, f } => {
                Self::one_arm_counts(*take_true, t, f, multiplier)
            }
            Fragment::Loop { n, body } => body.predicted_counts_inner(multiplier * *n),
            Fragment::Switch { case, a, b } => {
                // Any key other than "a" is predicted to select arm "b".
                Self::one_arm_counts(*case == "a", a, b, multiplier)
            }
            Fragment::Fallible { handler } => handler.predicted_counts_inner(multiplier),
            Fragment::Scope { body } => body.predicted_counts_inner(multiplier),
        }
    }

    /// Counts for a two-armed construct of which exactly one arm runs.
    ///
    /// The taken arm gets its predictions at `multiplier`; the other arm gets zeros. The
    /// result lists `first`'s trackers before `second`'s, matching `build`'s order.
    fn one_arm_counts(
        take_first: bool,
        first: &Fragment,
        second: &Fragment,
        multiplier: usize,
    ) -> Vec<usize> {
        let (mut counts, rest) = if take_first {
            (first.predicted_counts_inner(multiplier), second.zero_counts())
        } else {
            (first.zero_counts(), second.predicted_counts_inner(multiplier))
        };
        counts.extend(rest);
        counts
    }
}
/// Leaf fragment: one tracker node.
fn track() -> Fragment {
    Fragment::Track
}

/// Builds each of `items` in turn.
fn seq<I: IntoIterator<Item = Fragment>>(items: I) -> Fragment {
    let steps: Vec<Fragment> = items.into_iter().collect();
    Fragment::Seq(steps)
}

/// Builds each of `branches` as one arm of a parallel node.
fn par<I: IntoIterator<Item = Fragment>>(branches: I) -> Fragment {
    let arms: Vec<Fragment> = branches.into_iter().collect();
    Fragment::Par(arms)
}

/// Decision node whose system chooses `t` when `take_true` holds, otherwise `f`.
fn decision(take_true: bool, t: Fragment, f: Fragment) -> Fragment {
    let (t, f) = (Box::new(t), Box::new(f));
    Fragment::Decision { take_true, t, f }
}

/// Loop node running `body` for `n` iterations.
fn loop_n(n: usize, body: Fragment) -> Fragment {
    let body = Box::new(body);
    Fragment::Loop { n, body }
}

/// Switch node keyed by `case`, with arms "a" and "b".
fn switch(case: &'static str, a: Fragment, b: Fragment) -> Fragment {
    let (a, b) = (Box::new(a), Box::new(b));
    Fragment::Switch { case, a, b }
}

/// Failing node with `handler` attached as its error handler.
fn fallible(handler: Fragment) -> Fragment {
    let handler = Box::new(handler);
    Fragment::Fallible { handler }
}

/// Scope node wrapping `body` in its own graph.
fn scope(body: Fragment) -> Fragment {
    let body = Box::new(body);
    Fragment::Scope { body }
}
/// Builds `fragment` into a fresh `Graph` and executes it against a test-server context.
///
/// A fresh `ExecutionLog` is inserted into that context before execution. Returns the log
/// together with the tracker node ids, in the order `build` registered them.
///
/// # Errors
///
/// Propagates whatever `ExecutionError` `GraphExecutor::execute` reports.
async fn run_fragment(fragment: Fragment) -> Result<(ExecutionLog, Vec<NodeId>), ExecutionError> {
    // Construct the graph and collect the tracker ids it registered.
    let mut g = Graph::new();
    let tracker_ids = {
        let registry = TrackerNodes::default();
        fragment.build(&mut g, &registry);
        registry.into_vec()
    };

    // Prepare an execution context carrying the log. The caller keeps a clone; the
    // assertions rely on it observing what execution records (presumably a shared handle).
    let server = create_test_server();
    let hooks = get_hooks(&server);
    let mut ctx = server.create_context();
    let log = ExecutionLog::default();
    ctx.insert(log.clone());

    // Run the graph; any execution failure goes straight back to the caller.
    GraphExecutor::new()
        .execute(&g, &mut ctx, hooks, None)
        .await?;
    Ok((log, tracker_ids))
}
fn sequence_n(count: usize) -> Fragment {
seq((0..count).map(|_| track()))
}
fn parallel_2() -> Fragment {
par([track(), track()])
}
fn decision_true() -> Fragment {
decision(true, track(), track())
}
fn loop_body_n(iterations: usize) -> Fragment {
loop_n(iterations, track())
}
fn switch_to_a() -> Fragment {
switch("a", track(), track())
}
fn parallel_with_before_after() -> Fragment {
seq([track(), par([track(), track()]), track()])
}
fn decision_with_parallel() -> Fragment {
decision(true, parallel_with_before_after(), track())
}
fn switch_with_parallel() -> Fragment {
switch("a", parallel_with_before_after(), track())
}
fn loop_with_parallel(iterations: usize) -> Fragment {
loop_n(iterations, par([track(), track()]))
}
fn complex_nested() -> Fragment {
decision(
true,
seq([
par([decision_with_parallel(), loop_with_parallel(2)]),
track(), ]),
track(), )
}
#[tokio::test]
async fn test_sequence() {
    // Three trackers in a row: each fires once and the log lists them in build order.
    let (log, nodes) = run_fragment(sequence_n(3)).await.unwrap();
    for node in &nodes[..3] {
        assert_eq!(log.count(node), 1);
    }
    assert_eq!(log.executed(), nodes);
}
#[tokio::test]
async fn test_parallel() {
    let (log, nodes) = run_fragment(parallel_2()).await.unwrap();
    // (tracker index, failure message): both branches run exactly once.
    let expectations = [(0, "branch a should execute"), (1, "branch b should execute")];
    for (idx, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), 1, "{}", msg);
    }
}
#[tokio::test]
async fn test_decision() {
    let (log, nodes) = run_fragment(decision_true()).await.unwrap();
    // (tracker index, expected runs, failure message).
    let expectations = [
        (0, 1, "true branch should execute"),
        (1, 0, "false branch should not execute"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_loop() {
    let (log, nodes) = run_fragment(loop_body_n(5)).await.unwrap();
    let body_runs = log.count(&nodes[0]);
    assert_eq!(body_runs, 5, "loop body should execute 5 times");
}
#[tokio::test]
async fn test_switch() {
    let (log, nodes) = run_fragment(switch_to_a()).await.unwrap();
    // (tracker index, expected runs, failure message).
    let expectations = [
        (0, 1, "case a should execute"),
        (1, 0, "case b should not execute"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_parallel_converges() {
    let (log, nodes) = run_fragment(parallel_with_before_after()).await.unwrap();
    // (tracker index, expected runs, failure message), in build order.
    let expectations = [
        (0, 1, "before should execute"),
        (1, 1, "branch a should execute"),
        (2, 1, "branch b should execute"),
        (3, 1, "after should execute (convergence)"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_decision_with_nested_parallel() {
    let (log, nodes) = run_fragment(decision_with_parallel()).await.unwrap();
    // (tracker index, expected runs, failure message), in build order.
    let expectations = [
        (0, 1, "before parallel should execute"),
        (1, 1, "parallel branch a should execute"),
        (2, 1, "parallel branch b should execute"),
        (3, 1, "after parallel should execute"),
        (4, 0, "false branch should not execute"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_switch_with_nested_parallel() {
    let (log, nodes) = run_fragment(switch_with_parallel()).await.unwrap();
    // (tracker index, expected runs, failure message), in build order.
    let expectations = [
        (0, 1, "before parallel should execute"),
        (1, 1, "parallel branch a should execute"),
        (2, 1, "parallel branch b should execute"),
        (3, 1, "after parallel should execute"),
        (4, 0, "case b should not execute"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_loop_with_nested_parallel() {
    let (log, nodes) = run_fragment(loop_with_parallel(3)).await.unwrap();
    // Both parallel branches run once per loop iteration.
    let expectations = [
        (0, 3, "parallel branch a should execute 3 times"),
        (1, 3, "parallel branch b should execute 3 times"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_complex_nested_composition() {
    let (log, nodes) = run_fragment(complex_nested()).await.unwrap();
    // (tracker index, expected runs, failure message), in build order.
    let expectations = [
        (0, 1, "inner decision: before parallel"),
        (1, 1, "inner decision: parallel branch a"),
        (2, 1, "inner decision: parallel branch b"),
        (3, 1, "inner decision: after parallel"),
        (4, 0, "inner decision: false branch not taken"),
        (5, 2, "loop parallel branch a (2 iterations)"),
        (6, 2, "loop parallel branch b (2 iterations)"),
        (7, 1, "after outer parallel"),
        (8, 0, "outer false branch not taken"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_fallible_handler_executes() {
    let (log, nodes) = run_fragment(fallible(track())).await.unwrap();
    let handler_runs = log.count(&nodes[0]);
    assert_eq!(handler_runs, 1, "error handler tracker should fire");
}
#[tokio::test]
async fn test_fallible_in_sequence() {
    let failing_step = fallible(track());
    let (log, nodes) = run_fragment(seq([track(), failing_step, track()]))
        .await
        .unwrap();
    // (tracker index, expected runs, failure message), in build order.
    let expectations = [
        (0, 1, "before should execute"),
        (1, 1, "error handler should execute"),
        (2, 0, "after should not execute (unreachable after error diversion)"),
    ];
    for (idx, want, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), want, "{}", msg);
    }
}
#[tokio::test]
async fn test_fallible_in_loop() {
    let (log, nodes) = run_fragment(loop_n(3, fallible(track()))).await.unwrap();
    let handler_runs = log.count(&nodes[0]);
    assert_eq!(handler_runs, 3, "error handler should fire 3 times");
}
#[tokio::test]
async fn test_fallible_with_seq_handler() {
    let handler = seq([track(), track()]);
    let (log, nodes) = run_fragment(fallible(handler)).await.unwrap();
    let expectations = [
        (0, "handler step 1 should fire"),
        (1, "handler step 2 should fire"),
    ];
    for (idx, msg) in expectations {
        assert_eq!(log.count(&nodes[idx]), 1, "{}", msg);
    }
}
#[tokio::test]
async fn test_arbitrary_depth_nesting() {
    const DEPTH: usize = 5;

    /// A full binary tree of `depth` decision levels. Every decision takes its true arm,
    /// which is a two-way parallel over the next level; false arms are empty sequences.
    /// `Track` leaves sit at the bottom, so the tree holds `2^depth` trackers.
    fn nested_decisions(depth: usize) -> Fragment {
        if depth == 0 {
            track()
        } else {
            decision(
                true,
                par([nested_decisions(depth - 1), nested_decisions(depth - 1)]),
                seq([]),
            )
        }
    }

    let (log, nodes) = match run_fragment(nested_decisions(DEPTH)).await {
        Ok(outcome) => outcome,
        Err(err) => panic!("Depth {} nesting failed: {:?}", DEPTH, err),
    };

    // Succeeding is not enough: every leaf sits on a taken path, so each must have run
    // exactly once. A builder that dropped or duplicated nested branches would fail here.
    assert_eq!(nodes.len(), 1 << DEPTH, "expected one tracker per leaf");
    for (i, node) in nodes.iter().enumerate() {
        assert_eq!(
            log.count(node),
            1,
            "leaf tracker[{}] should execute exactly once",
            i
        );
    }
}
#[tokio::test]
async fn test_scope() {
    let (log, nodes) = run_fragment(scope(track())).await.unwrap();
    let body_runs = log.count(&nodes[0]);
    assert_eq!(body_runs, 1, "scope body should execute");
}
#[tokio::test]
async fn test_scope_in_sequence() {
let fragment = seq([track(), scope(par([track(), track()])), track()]);
let (log, nodes) = run_fragment(fragment).await.unwrap();
assert_eq!(log.count(&nodes[0]), 1, "before scope");
assert_eq!(log.count(&nodes[1]), 1, "scope parallel branch a");
assert_eq!(log.count(&nodes[2]), 1, "scope parallel branch b");
assert_eq!(log.count(&nodes[3]), 1, "after scope");
}
#[tokio::test]
async fn test_scope_in_loop() {
    let (log, nodes) = run_fragment(loop_n(3, scope(track()))).await.unwrap();
    let body_runs = log.count(&nodes[0]);
    assert_eq!(body_runs, 3, "scope body should execute 3 times");
}
#[tokio::test]
async fn test_arbitrary_composition() {
let complex = seq([
track(),
par([
decision(true, par([track(), track()]), track()),
loop_n(2, par([track(), track()])),
]),
track(),
]);
let (log, nodes) = run_fragment(complex).await.unwrap();
assert_eq!(log.count(&nodes[0]), 1, "initial track");
assert_eq!(log.count(&nodes[1]), 1, "decision true branch a");
assert_eq!(log.count(&nodes[2]), 1, "decision true branch b");
assert_eq!(log.count(&nodes[3]), 0, "decision false branch not taken");
assert_eq!(log.count(&nodes[4]), 2, "loop branch a (2 iterations)");
assert_eq!(log.count(&nodes[5]), 2, "loop branch b (2 iterations)");
assert_eq!(log.count(&nodes[6]), 1, "final track");
}
mod prop_tests {
    use super::*;
    use proptest::prelude::*;

    /// Strategy for random `Fragment` trees with exactly `depth` structural levels.
    ///
    /// `Track` leaves only appear at depth 0, so every generated tree has full height.
    /// Fan-out is bounded to keep graphs small:
    /// - 1-3 sequence items,
    /// - 2-4 parallel branches,
    /// - 1-4 loop iterations.
    fn arb_fragment(depth: u32) -> BoxedStrategy<Fragment> {
        if depth == 0 {
            return Just(Fragment::Track).boxed();
        }
        // Build the child strategy once and share it. Cloning a `BoxedStrategy` only clones
        // a handle, whereas rebuilding it at each of the nine use sites made strategy
        // construction grow as 9^depth. The generated values are unchanged.
        let child = arb_fragment(depth - 1);
        prop_oneof![
            prop::collection::vec(child.clone(), 1..=3usize).prop_map(Fragment::Seq),
            prop::collection::vec(child.clone(), 2..=4usize).prop_map(Fragment::Par),
            (any::<bool>(), child.clone(), child.clone())
                .prop_map(|(b, t, f)| decision(b, t, f)),
            (1..=4usize, child.clone()).prop_map(|(n, body)| loop_n(n, body)),
            (prop_oneof![Just("a"), Just("b")], child.clone(), child.clone())
                .prop_map(|(c, a, b)| switch(c, a, b)),
            child.clone().prop_map(fallible),
            child.clone().prop_map(scope),
        ]
        .boxed()
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(256))]
        /// Every tracker's observed execution count matches `Fragment::predicted_counts`.
        #[test]
        fn prop_per_node_execution_matches_prediction(fragment in arb_fragment(3)) {
            // Each case builds its own runtime to drive the async executor.
            let rt = tokio::runtime::Runtime::new().expect("tokio runtime");
            rt.block_on(async {
                let expected = fragment.predicted_counts();
                let (log, nodes) = run_fragment(fragment).await.expect("execution");
                // Checked before `zip`, which would otherwise silently truncate the
                // comparison; reported as a proptest failure like the per-tracker checks.
                prop_assert_eq!(nodes.len(), expected.len(), "tracker count mismatch");
                for (i, (node, expected_count)) in nodes.iter().zip(&expected).enumerate() {
                    prop_assert_eq!(
                        log.count(node),
                        *expected_count,
                        "tracker[{}]", i
                    );
                }
                Ok(())
            })?;
        }
    }
}