#![allow(clippy::pedantic, clippy::nursery)]
use crate::combinator::join::join2_outcomes;
use crate::combinator::race::{RaceWinner, race2_outcomes};
use crate::types::Outcome;
use crate::types::cancel::CancelReason;
use crate::types::outcome::{PanicPayload, Severity};
/// Minimal error payload used in test outcomes; the `u8` is the caller-supplied tag.
#[derive(Clone, Debug, Eq, PartialEq)]
struct TestErr(u8);
/// Builds a representative `Outcome` for the requested severity level.
///
/// `tag` is threaded into the payload (the `Ok` value, the `TestErr` code, or
/// the panic message) so different call sites yield distinguishable outcomes.
/// The `Cancelled` variant ignores `tag` and always carries
/// `CancelReason::race_loser()`.
fn rep(sev: Severity, tag: u8) -> Outcome<u32, TestErr> {
    match sev {
        Severity::Panicked => {
            let message = format!("panic-{tag}");
            Outcome::Panicked(PanicPayload::new(message))
        }
        Severity::Cancelled => Outcome::Cancelled(CancelReason::race_loser()),
        Severity::Err => Outcome::Err(TestErr(tag)),
        // Lossless widening; equivalent to `tag as u32`.
        Severity::Ok => Outcome::Ok(u32::from(tag)),
    }
}
const LATTICE: [Severity; 4] = [
Severity::Ok,
Severity::Err,
Severity::Cancelled,
Severity::Panicked,
];
/// Severity of the left-hand shape `race(join(a, b), join(a, c))`.
///
/// Both joins are materialised (so `a` appears twice), raced with the given
/// `winner`, and the result is the worse of the winner's and loser's severities.
fn lhs_severity(a: Severity, b: Severity, c: Severity, winner: RaceWinner) -> Severity {
    let joined_ab = join2_outcomes::<u32, u32, TestErr>(rep(a, 1), rep(b, 2)).0;
    let joined_ac = join2_outcomes::<u32, u32, TestErr>(rep(a, 3), rep(c, 4)).0;
    let (won, _winner_side, lost) = race2_outcomes::<(u32, u32), TestErr>(
        winner,
        wrap_join(joined_ab),
        wrap_join(joined_ac),
    );
    max_sev(won.severity(), lost.severity())
}
/// Severity of the right-hand shape `join(a, race(b, c))`.
///
/// `b` and `c` are raced first; the worse of the two race outcomes is then
/// joined with a single `a`, and the join's severity is returned.
fn rhs_severity(a: Severity, b: Severity, c: Severity, winner: RaceWinner) -> Severity {
    let (won, _winner_side, lost) =
        race2_outcomes::<u32, TestErr>(winner, rep(b, 2), rep(c, 3));
    let race_result = worst_outcome(won, lost);
    join2_outcomes::<u32, u32, TestErr>(rep(a, 1), race_result)
        .0
        .severity()
}
/// Identity adapter applied to each join result before it enters the race;
/// it returns its argument unchanged.
///
/// NOTE(review): presumably kept as a hook for future wrapping — confirm
/// before removing.
fn wrap_join(o: Outcome<(u32, u32), TestErr>) -> Outcome<(u32, u32), TestErr> {
    std::convert::identity(o)
}
/// Returns whichever outcome has the strictly higher severity; otherwise
/// (equal severities) the first argument `a` is kept.
fn worst_outcome<T, E>(a: Outcome<T, E>, b: Outcome<T, E>) -> Outcome<T, E> {
    let b_is_worse = b.severity() > a.severity();
    if b_is_worse {
        b
    } else {
        a
    }
}
/// Lattice join (maximum) of two severities; ties resolve to `a`.
fn max_sev(a: Severity, b: Severity) -> Severity {
    let a_dominates = a >= b;
    if a_dominates {
        a
    } else {
        b
    }
}
#[cfg(test)]
mod mr_severity {
    use super::*;

    /// Both race winners, each paired with a label for assertion messages.
    ///
    /// Each entry is a constructor rather than a value, so callers obtain a
    /// fresh `RaceWinner` per use without requiring `RaceWinner: Copy`.
    fn winners() -> [(&'static str, fn() -> RaceWinner); 2] {
        [("First", || RaceWinner::First), ("Second", || RaceWinner::Second)]
    }

    /// Every `(a, b, c)` triple over `LATTICE`, `a`-major and `c`-minor.
    fn triples() -> Vec<(Severity, Severity, Severity)> {
        let mut out = Vec::with_capacity(LATTICE.len().pow(3));
        for &a in &LATTICE {
            for &b in &LATTICE {
                for &c in &LATTICE {
                    out.push((a, b, c));
                }
            }
        }
        out
    }

    /// Asserts LAW-RACE-JOIN-DIST at severity granularity for every lattice
    /// triple, using the race winner produced by `make_winner`.
    fn assert_law(label: &str, make_winner: fn() -> RaceWinner) {
        for (a, b, c) in triples() {
            let lhs = lhs_severity(a, b, c, make_winner());
            let rhs = rhs_severity(a, b, c, make_winner());
            assert_eq!(
                lhs, rhs,
                "LAW-RACE-JOIN-DIST severity mismatch (winner={label}): \
                 race(join({a:?},{b:?}),join({a:?},{c:?})) = {lhs:?}, \
                 join({a:?},race({b:?},{c:?})) = {rhs:?}",
            );
        }
    }

    #[test]
    fn law_race_join_dist_severity_first_winner() {
        assert_law("First", || RaceWinner::First);
    }

    #[test]
    fn law_race_join_dist_severity_second_winner() {
        assert_law("Second", || RaceWinner::Second);
    }

    /// Neither shape's aggregated severity may depend on which branch wins.
    #[test]
    fn law_race_join_dist_severity_winner_invariant() {
        for (a, b, c) in triples() {
            let first = lhs_severity(a, b, c, RaceWinner::First);
            let second = lhs_severity(a, b, c, RaceWinner::Second);
            assert_eq!(
                first, second,
                "LHS severity is winner-dependent for (a={a:?}, b={b:?}, c={c:?})",
            );
            let first_r = rhs_severity(a, b, c, RaceWinner::First);
            let second_r = rhs_severity(a, b, c, RaceWinner::Second);
            assert_eq!(
                first_r, second_r,
                "RHS severity is winner-dependent for (a={a:?}, b={b:?}, c={c:?})",
            );
        }
    }

    /// An all-`Ok` input must yield `Ok` on both sides, for either race winner.
    #[test]
    fn unit_preservation() {
        let ok = Severity::Ok;
        for (label, make_winner) in winners() {
            assert_eq!(
                lhs_severity(ok, ok, ok, make_winner()),
                Severity::Ok,
                "LHS did not preserve the unit (winner={label})",
            );
            assert_eq!(
                rhs_severity(ok, ok, ok, make_winner()),
                Severity::Ok,
                "RHS did not preserve the unit (winner={label})",
            );
        }
    }

    /// A panicking `a` must dominate every `b`/`c` on both sides, for either
    /// race winner.
    #[test]
    fn panic_absorbs_both_sides() {
        for (label, make_winner) in winners() {
            for &b in &LATTICE {
                for &c in &LATTICE {
                    let lhs = lhs_severity(Severity::Panicked, b, c, make_winner());
                    let rhs = rhs_severity(Severity::Panicked, b, c, make_winner());
                    assert_eq!(
                        lhs,
                        Severity::Panicked,
                        "LHS did not absorb panic for (b={b:?}, c={c:?}, winner={label})",
                    );
                    assert_eq!(
                        rhs,
                        Severity::Panicked,
                        "RHS did not absorb panic for (b={b:?}, c={c:?}, winner={label})",
                    );
                }
            }
        }
    }
}
#[cfg(test)]
mod mr_run_once {
    use super::*;
    use std::cell::Cell;

    /// Stand-in for the shared sub-computation `a`: every `run` increments the
    /// borrowed poll counter and yields `Outcome::Ok(7)`.
    struct A<'c> {
        polls: &'c Cell<u32>,
    }

    impl<'c> A<'c> {
        fn new(polls: &'c Cell<u32>) -> Self {
            Self { polls }
        }

        /// Records one execution of `a` and returns a successful outcome.
        fn run(&self) -> Outcome<u32, TestErr> {
            self.polls.set(self.polls.get() + 1);
            Outcome::Ok(7)
        }
    }

    /// Evaluates `race(join(a, b), join(a, c))`, running a separate `a` for
    /// each join, and returns the worse of the race winner/loser severities.
    fn run_lhs(polls: &Cell<u32>, b: Outcome<u32, TestErr>, c: Outcome<u32, TestErr>) -> Severity {
        let a1 = A::new(polls).run();
        let a2 = A::new(polls).run();
        let (left, _, _) = join2_outcomes::<u32, u32, TestErr>(a1, b);
        let (right, _, _) = join2_outcomes::<u32, u32, TestErr>(a2, c);
        let (w, _, l) = race2_outcomes::<(u32, u32), TestErr>(RaceWinner::First, left, right);
        max_sev(w.severity(), l.severity())
    }

    /// Evaluates `join(a, race(b, c))`, running `a` exactly once.
    fn run_rhs(polls: &Cell<u32>, b: Outcome<u32, TestErr>, c: Outcome<u32, TestErr>) -> Severity {
        let a = A::new(polls).run();
        let (w, _, l) = race2_outcomes::<u32, TestErr>(RaceWinner::First, b, c);
        let inner = worst_outcome(w, l);
        let (joined, _, _) = join2_outcomes::<u32, u32, TestErr>(a, inner);
        joined.severity()
    }

    #[test]
    fn rhs_runs_a_exactly_once() {
        let polls = Cell::new(0);
        let _ = run_rhs(&polls, Outcome::Ok(1), Outcome::Ok(2));
        assert_eq!(
            polls.get(),
            1,
            "RHS shape `join(a, race(b,c))` must poll `a` exactly once \
             — the rewrite engine relies on this to collapse duplicate work."
        );
    }

    #[test]
    fn lhs_runs_a_twice_as_baseline() {
        let polls = Cell::new(0);
        let _ = run_lhs(&polls, Outcome::Ok(1), Outcome::Ok(2));
        assert_eq!(
            polls.get(),
            2,
            "LHS shape `race(join(a,b), join(a,c))` literally contains two \
             independent `a` sub-trees and therefore polls twice — this is \
             the motivation for the LAW-RACE-JOIN-DIST rewrite.",
        );
    }

    #[test]
    fn rewrite_saves_one_a_execution() {
        let inputs: &[(Outcome<u32, TestErr>, Outcome<u32, TestErr>)] = &[
            (Outcome::Ok(10), Outcome::Ok(20)),
            (Outcome::Ok(10), Outcome::Err(TestErr(2))),
            (Outcome::Err(TestErr(1)), Outcome::Ok(20)),
            (Outcome::Err(TestErr(1)), Outcome::Err(TestErr(2))),
        ];
        for (b, c) in inputs {
            let lhs_polls = Cell::new(0);
            let rhs_polls = Cell::new(0);
            let _ = run_lhs(&lhs_polls, b.clone(), c.clone());
            let _ = run_rhs(&rhs_polls, b.clone(), c.clone());
            let (lhs, rhs) = (lhs_polls.get(), rhs_polls.get());
            // Compare via addition rather than `lhs - rhs`: the u32 subtraction
            // underflows (debug panic / release wrap) if a regression ever made
            // the RHS poll more than the LHS, which would hide this message.
            assert_eq!(
                lhs,
                rhs + 1,
                "expected exactly one saved `a` execution per rewrite \
                 (lhs polled {lhs}, rhs polled {rhs})",
            );
        }
    }
}
#[cfg(test)]
mod mr_mask {
    use super::*;

    /// Common signature of `lhs_severity` and `rhs_severity`.
    type SeverityOf = fn(Severity, Severity, Severity, RaceWinner) -> Severity;

    /// Asserts that `shape` never reports a severity above the lattice join
    /// `max(a, b, c)` of its inputs, for every lattice triple and both race
    /// winners. `side` labels the shape in failure messages.
    fn assert_bounded(side: &str, shape: SeverityOf) {
        for &a in &LATTICE {
            for &b in &LATTICE {
                for &c in &LATTICE {
                    let bound = max_sev(a, max_sev(b, c));
                    let got_first = shape(a, b, c, RaceWinner::First);
                    let got_second = shape(a, b, c, RaceWinner::Second);
                    assert!(
                        got_first <= bound,
                        "{side} severity {got_first:?} exceeds lattice bound {bound:?} \
                         for (a={a:?}, b={b:?}, c={c:?}, winner=First)"
                    );
                    assert!(
                        got_second <= bound,
                        "{side} severity {got_second:?} exceeds lattice bound {bound:?} \
                         for (a={a:?}, b={b:?}, c={c:?}, winner=Second)"
                    );
                }
            }
        }
    }

    #[test]
    fn rhs_severity_bounded_by_input_lattice_join() {
        assert_bounded("RHS", rhs_severity);
    }

    /// Previously only checked with `RaceWinner::First`; now mirrors the RHS
    /// check by covering both winners.
    #[test]
    fn lhs_severity_bounded_by_input_lattice_join() {
        assert_bounded("LHS", lhs_severity);
    }
}