use std::cell::{Cell, RefCell};
use std::collections::{HashSet, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
use crate::net::{NodeId, Topology};
use crate::queue::EventQueue;
use crate::time::{Duration, Time};
/// Normalises an unordered node pair so that `(a, b)` and `(b, a)` map to the same
/// key: the node with the smaller `as_u32()` value always comes first.
fn pair_key(a: NodeId, b: NodeId) -> (NodeId, NodeId) {
    let swap = b.as_u32() < a.as_u32();
    if swap {
        (b, a)
    } else {
        (a, b)
    }
}
/// Identifier of a simulated task.
///
/// `TaskSimBuilder::build` creates exactly one task per node and gives task `i` to
/// node `i`; the run loop uses the inner value as an index into its task list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TaskId(pub u32);
/// A message as seen by the receiving node.
#[derive(Debug, Clone)]
pub struct Envelope<M> {
    /// Node that sent (or was injected as sending) the message.
    pub src: NodeId,
    /// Node whose inbox received the message.
    pub dst: NodeId,
    /// User payload.
    pub payload: M,
    /// Simulated time at which the send was performed.
    pub sent_at: Time,
    /// Simulated time at which the message was pushed into `dst`'s inbox.
    pub received_at: Time,
}
/// Entries stored in the simulator's event queue.
#[derive(Debug)]
enum TaskEvent<M> {
    /// Mark the task ready so the run loop polls it again (used by timers).
    Wake(TaskId),
    /// A message arriving at `dst`; the run loop appends it to `dst`'s inbox.
    Deliver {
        src: NodeId,
        dst: NodeId,
        payload: M,
        // Simulated time of the original send, copied into the `Envelope`.
        sent_at: Time,
    },
}
/// Per-node mailbox state shared between the run loop and the node's futures.
struct NodeState<M> {
    // Delivered but not yet received messages, consumed front-first by `Recv`/`RecvTimeout`.
    inbox: VecDeque<Envelope<M>>,
    // Set while a receive future is pending; a `Deliver` event only re-queues the
    // node's task when this is true.
    waiting_for_message: bool,
    // The task running on this node.
    task_id: TaskId,
}
/// Mutable simulator state, shared via `Rc<RefCell<..>>` by the run loop and every
/// `NodeContext` / future it hands out.
struct SimState<M> {
    // Current simulated time; only the run loop advances it.
    now: Time,
    // Network model used for routing and link/node failures.
    topology: Topology,
    // One entry per node, indexed by `NodeId::as_usize()`.
    nodes: Vec<NodeState<M>>,
    // Pending wakes and deliveries, ordered by the `EventQueue` implementation.
    events: EventQueue<TaskEvent<M>>,
    // Tasks to poll in the next round of the run loop.
    ready_tasks: Vec<TaskId>,
    // Shutdown flag handed out through `CancellationToken`s.
    cancelled: Rc<Cell<bool>>,
    // Unordered node pairs (normalised with `pair_key`) whose traffic is dropped.
    partitions: HashSet<(NodeId, NodeId)>,
    // Count of sends/injections dropped due to partitions or missing routes.
    messages_dropped: u64,
}
impl<M> SimState<M> {
    /// Queues a `Wake` event that will make `task_id` ready at simulated time `time`.
    fn schedule_wake(&mut self, time: Time, task_id: TaskId) {
        let wake = TaskEvent::Wake(task_id);
        self.events.schedule(time, wake);
    }
}
/// Cooperative shutdown signal shared by the simulation and its tasks.
///
/// All clones share one flag. The run loop itself never reads the flag; tasks must
/// poll `is_cancelled()` and stop on their own.
#[derive(Clone)]
pub struct CancellationToken {
    flag: Rc<Cell<bool>>,
}
impl CancellationToken {
pub fn is_cancelled(&self) -> bool {
self.flag.get()
}
pub fn cancel(&self) {
self.flag.set(true);
}
}
impl std::fmt::Debug for CancellationToken {
    /// Prints only the current flag value, e.g. `CancellationToken { cancelled: false }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.flag.get();
        let mut builder = f.debug_struct("CancellationToken");
        builder.field("cancelled", &cancelled);
        builder.finish()
    }
}
/// Handle given to each node's async body, exposing simulated time, messaging
/// and fault injection for that node.
pub struct NodeContext<M: 'static> {
    // Shared simulator state.
    state: Rc<RefCell<SimState<M>>>,
    // The node this context acts as.
    node_id: NodeId,
    // The task running this node; used when scheduling timer wakes.
    task_id: TaskId,
}
impl<M: 'static> NodeContext<M> {
    /// Returns the id of the node this context belongs to.
    pub fn id(&self) -> NodeId {
        self.node_id
    }

    /// Returns the current simulated time.
    pub fn now(&self) -> Time {
        let sim = self.state.borrow();
        sim.now
    }

    /// Returns a token sharing the simulation-wide shutdown flag.
    pub fn shutdown_token(&self) -> CancellationToken {
        let flag = Rc::clone(&self.state.borrow().cancelled);
        CancellationToken { flag }
    }

    /// Partitions `a` and `b`: later sends between them, in either direction, are
    /// dropped. The pair is unordered.
    pub fn partition(&self, a: NodeId, b: NodeId) {
        let key = pair_key(a, b);
        self.state.borrow_mut().partitions.insert(key);
    }

    /// Removes a partition previously created with `partition` (either order).
    pub fn heal(&self, a: NodeId, b: NodeId) {
        let key = pair_key(a, b);
        self.state.borrow_mut().partitions.remove(&key);
    }

    /// Reports whether `a` and `b` are currently partitioned (either order).
    pub fn is_partitioned(&self, a: NodeId, b: NodeId) -> bool {
        let key = pair_key(a, b);
        self.state.borrow().partitions.contains(&key)
    }

    /// Delegates to `Topology::fail_link`. This only affects routes computed by
    /// later sends; deliveries already scheduled still arrive.
    pub fn fail_link(&self, a: NodeId, b: NodeId) {
        let mut sim = self.state.borrow_mut();
        sim.topology.fail_link(a, b);
    }

    /// Delegates to `Topology::heal_link`.
    pub fn heal_link(&self, a: NodeId, b: NodeId) {
        let mut sim = self.state.borrow_mut();
        sim.topology.heal_link(a, b);
    }

    /// Delegates to `Topology::is_link_failed`.
    pub fn is_link_failed(&self, a: NodeId, b: NodeId) -> bool {
        let sim = self.state.borrow();
        sim.topology.is_link_failed(a, b)
    }

    /// Delegates to `Topology::fail_link_directed` for the `src -> dst` direction.
    pub fn fail_link_directed(&self, src: NodeId, dst: NodeId) {
        let mut sim = self.state.borrow_mut();
        sim.topology.fail_link_directed(src, dst);
    }

    /// Delegates to `Topology::heal_link_directed` for the `src -> dst` direction.
    pub fn heal_link_directed(&self, src: NodeId, dst: NodeId) {
        let mut sim = self.state.borrow_mut();
        sim.topology.heal_link_directed(src, dst);
    }

    /// Delegates to `Topology::is_link_failed_directed` for the `src -> dst` direction.
    pub fn is_link_failed_directed(&self, src: NodeId, dst: NodeId) -> bool {
        let sim = self.state.borrow();
        sim.topology.is_link_failed_directed(src, dst)
    }

    /// Delegates to `Topology::fail_node`.
    pub fn fail_node(&self, node: NodeId) {
        let mut sim = self.state.borrow_mut();
        sim.topology.fail_node(node);
    }

    /// Delegates to `Topology::heal_node`.
    pub fn heal_node(&self, node: NodeId) {
        let mut sim = self.state.borrow_mut();
        sim.topology.heal_node(node);
    }

    /// Delegates to `Topology::is_node_failed`.
    pub fn is_node_failed(&self, node: NodeId) -> bool {
        let sim = self.state.borrow();
        sim.topology.is_node_failed(node)
    }

    /// Returns a future that completes once simulated time reaches `now() + duration`.
    ///
    /// The deadline is fixed when `sleep` is called, not when the future is first polled.
    pub fn sleep(&self, duration: Duration) -> Sleep<M> {
        Sleep {
            wake_time: self.now() + duration,
            state: Rc::clone(&self.state),
            task_id: self.task_id,
            scheduled: false,
        }
    }

    /// Returns a future resolving to the next message in this node's inbox
    /// (front of the queue first). It has no timeout.
    pub fn recv(&self) -> Recv<M> {
        let state = Rc::clone(&self.state);
        Recv {
            state,
            node_id: self.node_id,
        }
    }

    /// Like `recv`, but resolves to `None` once simulated time reaches
    /// `now() + timeout` with no message available. The deadline is fixed here.
    pub fn recv_timeout(&self, timeout: Duration) -> RecvTimeout<M> {
        RecvTimeout {
            deadline: self.now() + timeout,
            state: Rc::clone(&self.state),
            node_id: self.node_id,
            task_id: self.task_id,
            wake_scheduled: false,
        }
    }
}
impl<M: Clone + 'static> NodeContext<M> {
    /// Returns a future that, when first polled, routes `payload` from this node to
    /// `dst` at the current simulated time. The future completes on that first poll;
    /// it does not wait for delivery.
    pub fn send(&self, dst: NodeId, payload: M) -> SendFut<M> {
        let payload = Some(payload);
        SendFut {
            src: self.node_id,
            dst,
            payload,
            state: self.state.clone(),
        }
    }
}
/// Future returned by `NodeContext::sleep`.
pub struct Sleep<M: 'static> {
    state: Rc<RefCell<SimState<M>>>,
    // Task to wake when the deadline is reached.
    task_id: TaskId,
    // Absolute simulated deadline, computed when the future was created.
    wake_time: Time,
    // Whether a `Wake` event has already been queued, so repeated polls don't
    // enqueue duplicates.
    scheduled: bool,
}
impl<M: 'static> Future for Sleep<M> {
    type Output = ();

    /// Completes once simulated time has reached `wake_time`. On the first pending
    /// poll it queues one `Wake` event for `wake_time`; later polls only re-check
    /// the clock.
    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let deadline_reached = self.state.borrow().now >= self.wake_time;
        if deadline_reached {
            return Poll::Ready(());
        }
        if !self.scheduled {
            let (when, task) = (self.wake_time, self.task_id);
            self.state.borrow_mut().schedule_wake(when, task);
            self.scheduled = true;
        }
        Poll::Pending
    }
}
/// Future returned by `NodeContext::recv`; resolves to the next inbox message.
pub struct Recv<M: 'static> {
    state: Rc<RefCell<SimState<M>>>,
    // Node whose inbox is read.
    node_id: NodeId,
}
impl<M: 'static> Future for Recv<M> {
    type Output = Envelope<M>;

    /// Pops the front of the node's inbox if present. Otherwise it marks the node as
    /// waiting, so the next `Deliver` event re-queues its task, and stays pending.
    ///
    /// NOTE(review): the waiting flag is only cleared when a receive completes. A
    /// `Recv` dropped while pending (e.g. losing a `select2`) leaves it set, which
    /// can cause extra polls of the task on later deliveries.
    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut sim = self.state.borrow_mut();
        let slot = &mut sim.nodes[self.node_id.as_usize()];
        let next = slot.inbox.pop_front();
        slot.waiting_for_message = next.is_none();
        match next {
            Some(envelope) => Poll::Ready(envelope),
            None => Poll::Pending,
        }
    }
}
/// Future returned by `NodeContext::recv_timeout`.
pub struct RecvTimeout<M: 'static> {
    state: Rc<RefCell<SimState<M>>>,
    // Node whose inbox is read.
    node_id: NodeId,
    // Task to wake at the deadline.
    task_id: TaskId,
    // Absolute simulated deadline, computed when the future was created.
    deadline: Time,
    // Whether the deadline `Wake` event has already been queued.
    wake_scheduled: bool,
}
impl<M: 'static> Future for RecvTimeout<M> {
    type Output = Option<Envelope<M>>;

    /// Resolves in this order:
    /// - `Some(envelope)` if the inbox is non-empty (checked first);
    /// - `None` once simulated time has reached the deadline;
    /// - otherwise stays pending, marking the node as waiting and, on the first
    ///   pending poll only, queueing a `Wake` event at the deadline.
    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let idx = self.node_id.as_usize();
        let deadline = self.deadline;
        let arm_timer = !self.wake_scheduled;
        {
            let mut sim = self.state.borrow_mut();
            let received = sim.nodes[idx].inbox.pop_front();
            if received.is_some() || sim.now >= deadline {
                sim.nodes[idx].waiting_for_message = false;
                return Poll::Ready(received);
            }
            sim.nodes[idx].waiting_for_message = true;
            if arm_timer {
                sim.schedule_wake(deadline, self.task_id);
            }
        }
        // The borrow of the shared state has ended, so `self` can be mutated.
        self.wake_scheduled = true;
        Poll::Pending
    }
}
/// Future returned by `NodeContext::send`.
///
/// On its first poll it routes the payload at the current simulated time:
/// - If the pair is partitioned, or the topology has no route, it is dropped and
///   counted in `messages_dropped`.
/// - Otherwise a delivery is scheduled after the route's `total_latency`.
///
/// The future always completes on that first poll. It does not wait for delivery,
/// and later polls are no-ops.
pub struct SendFut<M: 'static> {
    state: Rc<RefCell<SimState<M>>>,
    src: NodeId,
    dst: NodeId,
    // `Some` until the first poll moves it into the event queue.
    payload: Option<M>,
}

// `SendFut` never creates a pinned reference to any of its fields: the payload is
// moved out by value on the first poll. It therefore does not rely on structural
// pinning and may be `Unpin` for every `M`. This lets `poll` use the safe
// `Pin::get_mut` instead of `unsafe { get_unchecked_mut() }`.
impl<M: 'static> Unpin for SendFut<M> {}

impl<M: Clone + 'static> Future for SendFut<M> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        if let Some(payload) = this.payload.take() {
            let mut state = this.state.borrow_mut();
            let now = state.now;
            if state.partitions.contains(&pair_key(this.src, this.dst)) {
                state.messages_dropped += 1;
                return Poll::Ready(());
            }
            let route = match state.topology.route(this.src, this.dst) {
                Some(r) => r,
                None => {
                    state.messages_dropped += 1;
                    return Poll::Ready(());
                }
            };
            let delivery_time = now + route.total_latency;
            state.events.schedule(
                delivery_time,
                TaskEvent::Deliver {
                    src: this.src,
                    dst: this.dst,
                    payload,
                    sent_at: now,
                },
            );
        }
        Poll::Ready(())
    }
}
/// Output of `select2`: which of the two futures completed first.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Either<A, B> {
    /// The first future (`a`) completed.
    Left(A),
    /// The second future (`b`) completed.
    Right(B),
}
/// Races two futures, resolving to the output of whichever is ready first.
///
/// Polling is biased: `a` is always polled before `b`, so if both are ready in the
/// same poll the result is `Either::Left`. The losing future is dropped with the
/// `Select2` value.
pub fn select2<A: Future, B: Future>(a: A, b: B) -> Select2<A, B> {
    Select2 { a, b }
}
/// Future returned by `select2`.
///
/// Both fields are pinned structurally: the `Future` impl projects
/// `Pin<&mut Select2>` to `Pin<&mut A>` / `Pin<&mut B>`. The type must therefore
/// never move out of these fields or gain a `Drop` impl that does.
pub struct Select2<A, B> {
    a: A,
    b: B,
}
impl<A, B> Future for Select2<A, B>
where
A: Future,
B: Future,
{
type Output = Either<A::Output, B::Output>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = unsafe { self.get_unchecked_mut() };
let a = unsafe { Pin::new_unchecked(&mut this.a) };
if let Poll::Ready(v) = a.poll(cx) {
return Poll::Ready(Either::Left(v));
}
let b = unsafe { Pin::new_unchecked(&mut this.b) };
if let Poll::Ready(v) = b.poll(cx) {
return Poll::Ready(Either::Right(v));
}
Poll::Pending
}
}
/// One node's top-level future as owned by the run loop.
struct Task {
    // The boxed async body returned by the builder's `node_fn`.
    future: Pin<Box<dyn Future<Output = ()>>>,
    // Set once the future returned `Ready`; completed tasks are never polled again.
    completed: bool,
}
/// Builds a waker whose callbacks all do nothing.
///
/// The run loop decides what to poll from its own ready list and event queue, so
/// futures never need to signal through the waker.
fn create_waker() -> Waker {
    static NOOP_VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);

    fn noop_clone(data: *const ()) -> RawWaker {
        RawWaker::new(data, &NOOP_VTABLE)
    }

    // Shared by `wake`, `wake_by_ref` and `drop`: there is no state to touch.
    fn noop(_data: *const ()) {}

    // SAFETY: every vtable function ignores the data pointer (always null here), so
    // the `RawWaker` contract (thread-safety, clone/drop pairing) holds trivially.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &NOOP_VTABLE)) }
}
/// Configures a `TaskSim` before its node tasks are created.
pub struct TaskSimBuilder<M: 'static> {
    // Network the simulation runs on; its `node_count()` decides how many tasks exist.
    topology: Topology,
    // Stored and reported back via `TaskSim::seed`.
    seed: u64,
    // Ties the builder to the message type `M` without storing one.
    _phantom: std::marker::PhantomData<M>,
}
impl<M: 'static> TaskSimBuilder<M> {
pub fn new(topology: Topology, seed: u64) -> Self {
TaskSimBuilder {
topology,
seed,
_phantom: std::marker::PhantomData,
}
}
pub fn build<F, Fut>(self, node_fn: F) -> TaskSim<M>
where
F: Fn(NodeContext<M>) -> Fut,
Fut: Future<Output = ()> + 'static,
M: Clone,
{
let node_count = self.topology.node_count();
let state = Rc::new(RefCell::new(SimState {
now: Time::ZERO,
topology: self.topology,
nodes: (0..node_count)
.map(|i| NodeState {
inbox: VecDeque::new(),
waiting_for_message: false,
task_id: TaskId(i as u32),
})
.collect(),
events: EventQueue::new(),
ready_tasks: Vec::new(),
cancelled: Rc::new(Cell::new(false)),
partitions: HashSet::new(),
messages_dropped: 0,
}));
let mut tasks = Vec::with_capacity(node_count);
for i in 0..node_count {
let task_id = TaskId(i as u32);
let node_id = NodeId::new(i as u32);
let ctx = NodeContext {
state: Rc::clone(&state),
node_id,
task_id,
};
let future = node_fn(ctx);
tasks.push(Task {
future: Box::pin(future),
completed: false,
});
}
{
let mut s = state.borrow_mut();
for i in 0..node_count {
s.ready_tasks.push(TaskId(i as u32));
}
}
TaskSim {
state,
tasks,
seed: self.seed,
ready_scratch: Vec::new(),
}
}
}
/// Counters reported by `TaskSim::run` / `TaskSim::run_until`.
#[derive(Debug, Clone, Default)]
pub struct TaskSimStats {
    /// Simulated time when the run loop stopped.
    pub final_time: Time,
    /// Number of queue events (wakes and deliveries) processed.
    pub events_processed: u64,
    /// Number of `Deliver` events processed.
    pub messages_delivered: u64,
    /// Number of times any task future was polled.
    pub task_polls: u64,
    /// Number of task futures that returned `Ready`.
    pub tasks_completed: u64,
    /// Sends and injections dropped because of a partition or a missing route.
    pub messages_dropped: u64,
}
/// A built simulation: one task per node plus the shared state they operate on.
pub struct TaskSim<M: 'static> {
    // State shared with every `NodeContext` and future.
    state: Rc<RefCell<SimState<M>>>,
    // Indexed by `TaskId.0`.
    tasks: Vec<Task>,
    // Seed recorded by the builder.
    seed: u64,
    // Reusable buffer swapped with `SimState::ready_tasks` each loop iteration, so
    // tasks can be polled without holding a borrow of the shared state.
    ready_scratch: Vec<TaskId>,
}
impl<M: Clone + 'static> TaskSim<M> {
pub fn inject(&mut self, src: NodeId, dst: NodeId, payload: M) {
let mut state = self.state.borrow_mut();
let now = state.now;
if state.partitions.contains(&pair_key(src, dst)) {
state.messages_dropped += 1;
return;
}
let route = match state.topology.route(src, dst) {
Some(r) => r,
None => {
state.messages_dropped += 1;
return;
}
};
let delivery_time = now + route.total_latency;
state.events.schedule(
delivery_time,
TaskEvent::Deliver {
src,
dst,
payload,
sent_at: now,
},
);
}
pub fn seed(&self) -> u64 {
self.seed
}
pub fn partition(&mut self, a: NodeId, b: NodeId) {
self.state.borrow_mut().partitions.insert(pair_key(a, b));
}
pub fn heal(&mut self, a: NodeId, b: NodeId) {
self.state.borrow_mut().partitions.remove(&pair_key(a, b));
}
pub fn is_partitioned(&self, a: NodeId, b: NodeId) -> bool {
self.state.borrow().partitions.contains(&pair_key(a, b))
}
pub fn fail_link(&mut self, a: NodeId, b: NodeId) {
self.state.borrow_mut().topology.fail_link(a, b);
}
pub fn heal_link(&mut self, a: NodeId, b: NodeId) {
self.state.borrow_mut().topology.heal_link(a, b);
}
pub fn is_link_failed(&self, a: NodeId, b: NodeId) -> bool {
self.state.borrow().topology.is_link_failed(a, b)
}
pub fn fail_link_directed(&mut self, src: NodeId, dst: NodeId) {
self.state
.borrow_mut()
.topology
.fail_link_directed(src, dst);
}
pub fn heal_link_directed(&mut self, src: NodeId, dst: NodeId) {
self.state
.borrow_mut()
.topology
.heal_link_directed(src, dst);
}
pub fn is_link_failed_directed(&self, src: NodeId, dst: NodeId) -> bool {
self.state
.borrow()
.topology
.is_link_failed_directed(src, dst)
}
pub fn fail_node(&mut self, node: NodeId) {
self.state.borrow_mut().topology.fail_node(node);
}
pub fn heal_node(&mut self, node: NodeId) {
self.state.borrow_mut().topology.heal_node(node);
}
pub fn is_node_failed(&self, node: NodeId) -> bool {
self.state.borrow().topology.is_node_failed(node)
}
pub fn shutdown_token(&self) -> CancellationToken {
CancellationToken {
flag: Rc::clone(&self.state.borrow().cancelled),
}
}
pub fn run(self) -> TaskSimStats {
self.run_until(|_| true)
}
pub fn run_until<F>(mut self, mut continue_fn: F) -> TaskSimStats
where
F: FnMut(Time) -> bool,
{
let waker = create_waker();
let mut cx = Context::from_waker(&waker);
let mut stats = TaskSimStats::default();
loop {
self.ready_scratch.clear();
{
let mut state = self.state.borrow_mut();
std::mem::swap(&mut self.ready_scratch, &mut state.ready_tasks);
}
for &task_id in &self.ready_scratch {
if let Some(task) = self.tasks.get_mut(task_id.0 as usize) {
if task.completed {
continue;
}
stats.task_polls += 1;
if let Poll::Ready(()) = task.future.as_mut().poll(&mut cx) {
task.completed = true;
stats.tasks_completed += 1;
}
}
}
let next_event = self.state.borrow_mut().events.pop();
match next_event {
Some(scheduled) => {
if !continue_fn(scheduled.time) {
break;
}
{
let mut state = self.state.borrow_mut();
state.now = scheduled.time;
}
stats.events_processed += 1;
match scheduled.event {
TaskEvent::Wake(task_id) => {
let mut state = self.state.borrow_mut();
state.ready_tasks.push(task_id);
}
TaskEvent::Deliver {
src,
dst,
payload,
sent_at,
} => {
let mut state = self.state.borrow_mut();
let now = state.now;
if dst.as_usize() < state.nodes.len() {
let task_id = state.nodes[dst.as_usize()].task_id;
let waiting = state.nodes[dst.as_usize()].waiting_for_message;
state.nodes[dst.as_usize()].inbox.push_back(Envelope {
src,
dst,
payload,
sent_at,
received_at: now,
});
if waiting {
state.ready_tasks.push(task_id);
}
}
stats.messages_delivered += 1;
}
}
}
None => {
let has_ready = !self.state.borrow().ready_tasks.is_empty();
if !has_ready {
break;
}
}
}
}
let state = self.state.borrow();
stats.final_time = state.now;
stats.messages_dropped = state.messages_dropped;
stats
}
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for the task simulator. Each test builds a small topology,
    //! runs one async body per node, and asserts on `TaskSimStats` or on values
    //! observed inside the tasks.
    use super::*;
    use crate::net::TopologyBuilder;

    // A single sleep advances the simulated clock to its deadline.
    #[test]
    fn basic_sleep() {
        let topology = TopologyBuilder::new(1).build();
        let sim = TaskSimBuilder::<()>::new(topology, 42).build(|ctx| async move {
            ctx.sleep(Duration::from_millis(100)).await;
        });
        let stats = sim.run();
        assert_eq!(stats.final_time, Time::from_millis(100));
    }

    // Sequential sleeps accumulate: 3 x 50ms ends at 150ms.
    #[test]
    fn multiple_sleeps() {
        let topology = TopologyBuilder::new(1).build();
        let sim = TaskSimBuilder::<()>::new(topology, 42).build(|ctx| async move {
            ctx.sleep(Duration::from_millis(50)).await;
            ctx.sleep(Duration::from_millis(50)).await;
            ctx.sleep(Duration::from_millis(50)).await;
        });
        let stats = sim.run();
        assert_eq!(stats.final_time, Time::from_millis(150));
    }

    // A message sent by node 0 is received by node 1 with the right src/payload.
    #[test]
    fn send_recv() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
            match ctx.id().as_u32() {
                0 => {
                    ctx.send(NodeId(1), "hello".to_string()).await;
                }
                1 => {
                    let msg = ctx.recv().await;
                    assert_eq!(msg.payload, "hello");
                    assert_eq!(msg.src, NodeId(0));
                }
                _ => {}
            }
        });
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 1);
    }

    // The injected "start" message triggers ping then pong: 3 deliveries total.
    #[test]
    fn ping_pong() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let mut sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
            let id = ctx.id().as_u32();
            if id == 0 {
                let msg = ctx.recv().await;
                ctx.send(NodeId(1), format!("ping from 0, got: {}", msg.payload))
                    .await;
                let reply = ctx.recv().await;
                assert!(reply.payload.starts_with("pong"));
            } else {
                let msg = ctx.recv().await;
                assert!(msg.payload.starts_with("ping"));
                ctx.send(msg.src, "pong from 1".to_string()).await;
            }
        });
        sim.inject(NodeId(1), NodeId(0), "start".to_string());
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 3);
    }

    // Three nodes in a ring, each sending after a staggered sleep; all 3 arrive.
    #[test]
    fn concurrent_nodes() {
        let topology = TopologyBuilder::new(3)
            .full_mesh(Duration::from_millis(5))
            .build();
        let sim = TaskSimBuilder::<u32>::new(topology, 42).build(|ctx| async move {
            let id = ctx.id().as_u32();
            ctx.sleep(Duration::from_millis((id as u64 + 1) * 10)).await;
            let next = NodeId((id + 1) % 3);
            ctx.send(next, id).await;
            let _msg = ctx.recv().await;
        });
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 3);
    }

    // Two identical runs produce identical final time and delivery counts.
    #[test]
    fn deterministic_replay() {
        fn run_sim() -> (Time, u64) {
            let topology = TopologyBuilder::new(2)
                .link(0u32, 1u32, Duration::from_millis(10))
                .build();
            let sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
                if ctx.id().as_u32() == 0 {
                    ctx.sleep(Duration::from_millis(5)).await;
                    ctx.send(NodeId(1), "test".to_string()).await;
                } else {
                    let _ = ctx.recv().await;
                }
            });
            let stats = sim.run();
            (stats.final_time, stats.messages_delivered)
        }
        let (t1, m1) = run_sim();
        let (t2, m2) = run_sim();
        assert_eq!(t1, t2);
        assert_eq!(m1, m2);
    }

    // `ctx.now()` reflects the clock after each sleep completes.
    #[test]
    fn now_updates() {
        let topology = TopologyBuilder::new(1).build();
        let sim = TaskSimBuilder::<()>::new(topology, 42).build(|ctx| async move {
            assert_eq!(ctx.now(), Time::ZERO);
            ctx.sleep(Duration::from_millis(100)).await;
            assert_eq!(ctx.now(), Time::from_millis(100));
            ctx.sleep(Duration::from_millis(50)).await;
            assert_eq!(ctx.now(), Time::from_millis(150));
        });
        sim.run();
    }

    // A node can message itself; relies on `Topology::route` returning a route
    // from a node to itself even with no links.
    #[test]
    fn self_send() {
        let topology = TopologyBuilder::new(1).build();
        let sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
            ctx.send(ctx.id(), "hello self".to_string()).await;
            let msg = ctx.recv().await;
            assert_eq!(msg.src, ctx.id());
            assert_eq!(msg.dst, ctx.id());
            assert_eq!(msg.payload, "hello self");
        });
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 1);
    }

    // Messages injected at the same time with equal latency are received in
    // injection order (relies on the event queue preserving FIFO for ties).
    #[test]
    fn multiple_queued_messages() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let mut sim = TaskSimBuilder::<u32>::new(topology, 42).build(|ctx| async move {
            if ctx.id().as_u32() == 1 {
                let m1 = ctx.recv().await;
                let m2 = ctx.recv().await;
                let m3 = ctx.recv().await;
                assert_eq!(m1.payload, 1);
                assert_eq!(m2.payload, 2);
                assert_eq!(m3.payload, 3);
            }
        });
        sim.inject(NodeId(0), NodeId(1), 1);
        sim.inject(NodeId(0), NodeId(1), 2);
        sim.inject(NodeId(0), NodeId(1), 3);
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 3);
    }

    // Sleeps and receives interleave correctly on both sides: "first" arrives at
    // 15ms, "second" at 35ms, after the receiver's own 5ms sleep.
    #[test]
    fn interleaved_sleep_recv() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
            if ctx.id().as_u32() == 0 {
                ctx.sleep(Duration::from_millis(5)).await;
                ctx.send(NodeId(1), "first".to_string()).await;
                ctx.sleep(Duration::from_millis(20)).await;
                ctx.send(NodeId(1), "second".to_string()).await;
            } else {
                let m1 = ctx.recv().await;
                assert_eq!(m1.payload, "first");
                ctx.sleep(Duration::from_millis(5)).await;
                let m2 = ctx.recv().await;
                assert_eq!(m2.payload, "second");
            }
        });
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 2);
    }

    // With no links between the nodes, `inject` finds no route and records a drop.
    #[test]
    fn no_route_drops_instead_of_delivering() {
        let topology = TopologyBuilder::new(2).build();
        let mut sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
            if ctx.id().as_u32() == 1 {
                let _ = ctx.recv_timeout(Duration::from_millis(1)).await;
            }
        });
        sim.inject(NodeId(0), NodeId(1), "unreachable?".to_string());
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 0);
        assert_eq!(stats.messages_dropped, 1);
    }

    // Failing the only link before the run makes the send drop.
    #[test]
    fn fail_link_pre_run_drops_send() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let mut sim = TaskSimBuilder::<&'static str>::new(topology, 42).build(|ctx| async move {
            if ctx.id() == NodeId(0) {
                ctx.send(NodeId(1), "blocked").await;
            } else {
                let _ = ctx.recv_timeout(Duration::from_millis(1)).await;
            }
        });
        sim.fail_link(NodeId(0), NodeId(1));
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 0);
        assert_eq!(stats.messages_dropped, 1);
    }

    // Failing the 0-1 link forces 0->2 onto the direct 50ms link instead of the
    // 10ms path via node 1.
    #[test]
    fn fail_link_pre_run_reroutes_send() {
        let topology = TopologyBuilder::new(3)
            .link(0u32, 1u32, Duration::from_millis(5))
            .link(1u32, 2u32, Duration::from_millis(5))
            .link(0u32, 2u32, Duration::from_millis(50))
            .build();
        let mut sim = TaskSimBuilder::<&'static str>::new(topology, 42).build(|ctx| async move {
            if ctx.id() == NodeId(0) {
                ctx.send(NodeId(2), "long way").await;
            } else if ctx.id() == NodeId(2) {
                let msg = ctx.recv().await;
                assert_eq!(msg.payload, "long way");
                assert_eq!(msg.received_at, Time::from_millis(50));
            } else {
                let _ = ctx.recv_timeout(Duration::from_millis(1)).await;
            }
        });
        sim.fail_link(NodeId(0), NodeId(1));
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 1);
        assert_eq!(stats.messages_dropped, 0);
    }

    // Sends to a node failed before the run are dropped.
    #[test]
    fn fail_node_pre_run_drops_sends_to_failed_node() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let mut sim = TaskSimBuilder::<&'static str>::new(topology, 42).build(|ctx| async move {
            if ctx.id() == NodeId(0) {
                ctx.send(NodeId(1), "to_dead").await;
            }
        });
        sim.fail_node(NodeId(1));
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 0);
        assert_eq!(stats.messages_dropped, 1);
    }

    // A 0-1 partition drops traffic both ways while 0->2 still gets through.
    #[test]
    fn partition_drops_sends_in_both_directions() {
        let topology = TopologyBuilder::new(3)
            .full_mesh(Duration::from_millis(10))
            .build();
        let mut sim = TaskSimBuilder::<&'static str>::new(topology, 42).build(|ctx| async move {
            match ctx.id().as_u32() {
                0 => {
                    ctx.send(NodeId(1), "blocked_a").await;
                    ctx.send(NodeId(2), "still_works").await;
                }
                1 => {
                    ctx.send(NodeId(0), "blocked_b").await;
                    let _ = ctx.recv_timeout(Duration::from_millis(1)).await;
                }
                _ => {
                    let msg = ctx.recv().await;
                    assert_eq!(msg.payload, "still_works");
                }
            }
        });
        sim.partition(NodeId(0), NodeId(1));
        assert!(sim.is_partitioned(NodeId(0), NodeId(1)));
        let stats = sim.run();
        assert_eq!(stats.messages_dropped, 2);
        assert_eq!(stats.messages_delivered, 1);
    }

    // Failing and healing a link from inside a task changes routes for later sends:
    // 10ms via node 1, then 50ms direct while failed, then 10ms again after healing.
    #[test]
    fn node_context_can_fail_and_heal_link_mid_run() {
        let topology = TopologyBuilder::new(3)
            .link(0u32, 1u32, Duration::from_millis(5))
            .link(1u32, 2u32, Duration::from_millis(5))
            .link(0u32, 2u32, Duration::from_millis(50))
            .build();
        let arrivals: Rc<RefCell<Vec<(u32, Time)>>> = Rc::new(RefCell::new(Vec::new()));
        let arrivals_for_recv = Rc::clone(&arrivals);
        let sim = TaskSimBuilder::<u32>::new(topology, 42).build(move |ctx| {
            let arrivals = Rc::clone(&arrivals_for_recv);
            async move {
                if ctx.id() == NodeId(0) {
                    ctx.send(NodeId(2), 1).await;
                    ctx.sleep(Duration::from_millis(100)).await;
                    ctx.fail_link(NodeId(0), NodeId(1));
                    ctx.send(NodeId(2), 2).await;
                    ctx.sleep(Duration::from_millis(100)).await;
                    ctx.heal_link(NodeId(0), NodeId(1));
                    ctx.send(NodeId(2), 3).await;
                } else if ctx.id() == NodeId(2) {
                    for _ in 0..3 {
                        let msg = ctx.recv().await;
                        arrivals.borrow_mut().push((msg.payload, msg.received_at));
                    }
                }
            }
        });
        let _stats = sim.run();
        let arrivals = arrivals.borrow();
        assert_eq!(arrivals.len(), 3);
        assert_eq!(arrivals[0], (1, Time::from_millis(10)));
        assert_eq!(arrivals[1], (2, Time::from_millis(150)));
        assert_eq!(arrivals[2], (3, Time::from_millis(210)));
    }

    // One task finishing early doesn't stop the run; the longer sleep sets final_time.
    #[test]
    fn task_completes_early() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let sim = TaskSimBuilder::<()>::new(topology, 42).build(|ctx| async move {
            if ctx.id().as_u32() == 0 {
                ctx.sleep(Duration::from_millis(10)).await;
            } else {
                ctx.sleep(Duration::from_millis(100)).await;
            }
        });
        let stats = sim.run();
        assert_eq!(stats.final_time, Time::from_millis(100));
        assert_eq!(stats.tasks_completed, 2);
    }

    // With an empty inbox, recv_timeout resolves to None exactly at its deadline.
    #[test]
    fn recv_timeout_fires() {
        let topology = TopologyBuilder::new(1).build();
        let sim = TaskSimBuilder::<u32>::new(topology, 42).build(|ctx| async move {
            let result = ctx.recv_timeout(Duration::from_millis(50)).await;
            assert!(result.is_none());
            assert_eq!(ctx.now(), Time::from_millis(50));
        });
        let stats = sim.run();
        assert_eq!(stats.final_time, Time::from_millis(50));
        assert_eq!(stats.tasks_completed, 1);
    }

    // A message arriving before the deadline is returned as Some.
    #[test]
    fn recv_timeout_receives_before_deadline() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(10))
            .build();
        let mut sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
            if ctx.id().as_u32() == 1 {
                let result = ctx.recv_timeout(Duration::from_millis(100)).await;
                assert!(result.is_some());
                assert_eq!(result.unwrap().payload, "in-time");
            }
        });
        sim.inject(NodeId(0), NodeId(1), "in-time".to_string());
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 1);
    }

    // select2(recv, sleep) yields Left when the message arrives (5ms) before the
    // 100ms sleep.
    #[test]
    fn select2_prefers_recv_when_available() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(5))
            .build();
        let mut sim = TaskSimBuilder::<String>::new(topology, 42).build(|ctx| async move {
            if ctx.id().as_u32() == 1 {
                let either = select2(ctx.recv(), ctx.sleep(Duration::from_millis(100))).await;
                match either {
                    Either::Left(env) => assert_eq!(env.payload, "hi"),
                    Either::Right(_) => panic!("expected recv to win"),
                }
            }
        });
        sim.inject(NodeId(0), NodeId(1), "hi".to_string());
        let stats = sim.run();
        assert_eq!(stats.messages_delivered, 1);
    }

    // With no message ever sent, select2(recv, sleep) yields Right at 25ms.
    #[test]
    fn select2_falls_back_to_sleep() {
        let topology = TopologyBuilder::new(1).build();
        let sim = TaskSimBuilder::<u32>::new(topology, 42).build(|ctx| async move {
            let either = select2(ctx.recv(), ctx.sleep(Duration::from_millis(25))).await;
            assert!(matches!(either, Either::Right(_)));
            assert_eq!(ctx.now(), Time::from_millis(25));
        });
        let stats = sim.run();
        assert_eq!(stats.final_time, Time::from_millis(25));
    }

    // Cancelling before the run makes both tasks exit on their first check, so
    // nothing is ever sent.
    #[test]
    fn shutdown_token_observed_by_tasks() {
        let topology = TopologyBuilder::new(2)
            .link(0u32, 1u32, Duration::from_millis(1))
            .build();
        let sim = TaskSimBuilder::<u32>::new(topology, 42).build(|ctx| async move {
            let token = ctx.shutdown_token();
            for i in 0..1000u32 {
                if token.is_cancelled() {
                    break;
                }
                ctx.sleep(Duration::from_millis(1)).await;
                if ctx.id().as_u32() == 0 {
                    ctx.send(NodeId(1), i).await;
                }
            }
        });
        sim.shutdown_token().cancel();
        let stats = sim.run();
        assert_eq!(stats.tasks_completed, 2);
        assert_eq!(stats.messages_delivered, 0);
    }
}