1#![allow(clippy::collapsible_if)]
3
4use crate::{Disconnect, Io, Link, LocalDuration, LocalTime};
5use log::*;
6
7use std::borrow::Cow;
8use std::collections::{BTreeMap, BTreeSet, VecDeque};
9use std::ops::{Deref, DerefMut, Range};
10use std::{fmt, io, net};
11
12use crate::StateMachine;
13
14#[cfg(feature = "quickcheck")]
15pub mod arbitrary;
16
17pub const MIN_LATENCY: LocalDuration = LocalDuration::from_millis(1);
19pub const MAX_EVENTS: usize = 2048;
21
22type NodeId = net::IpAddr;
25
26pub trait Peer<P>: Deref<Target = P> + DerefMut<Target = P> + 'static
28where
29 P: StateMachine,
30{
31 fn init(&mut self);
34 fn addr(&self) -> net::SocketAddr;
36}
37
38#[derive(Debug, Clone)]
40pub enum Input<M, D> {
41 Connecting {
43 addr: net::SocketAddr,
45 },
46 Connected {
48 addr: net::SocketAddr,
50 local_addr: net::SocketAddr,
52 link: Link,
54 },
55 Disconnected(net::SocketAddr, Disconnect<D>),
57 Received(net::SocketAddr, M),
59 Wake,
61}
62
63#[derive(Debug, Clone)]
65pub struct Scheduled<M, D> {
66 pub node: NodeId,
68 pub remote: net::SocketAddr,
71 pub input: Input<M, D>,
73}
74
75impl<M: fmt::Debug, D: fmt::Display> fmt::Display for Scheduled<M, D> {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 match &self.input {
78 Input::Received(from, msg) => {
79 write!(f, "{} <- {} ({:?})", self.node, from, msg)
80 }
81 Input::Connected {
82 addr,
83 local_addr,
84 link: Link::Inbound,
85 ..
86 } => write!(f, "{} <== {}: Connected", local_addr, addr),
87 Input::Connected {
88 local_addr,
89 addr,
90 link: Link::Outbound,
91 ..
92 } => write!(f, "{} ==> {}: Connected", local_addr, addr),
93 Input::Connecting { addr } => {
94 write!(f, "{} => {}: Connecting", self.node, addr)
95 }
96 Input::Disconnected(addr, reason) => {
97 write!(f, "{} =/= {}: Disconnected: {}", self.node, addr, reason)
98 }
99 Input::Wake => {
100 write!(f, "{}: Tock", self.node)
101 }
102 }
103 }
104}
105
106#[derive(Debug)]
108pub struct Inbox<M, D> {
109 messages: BTreeMap<LocalTime, Scheduled<M, D>>,
112}
113
114impl<M: Clone, D: Clone> Inbox<M, D> {
115 fn insert(&mut self, mut time: LocalTime, msg: Scheduled<M, D>) {
117 while self.messages.contains_key(&time) {
119 time = time + MIN_LATENCY;
120 }
121 self.messages.insert(time, msg);
122 }
123
124 fn next(&mut self) -> Option<(LocalTime, Scheduled<M, D>)> {
126 self.messages
127 .iter()
128 .next()
129 .map(|(time, scheduled)| (*time, scheduled.clone()))
130 }
131
132 fn last(
134 &self,
135 node: &NodeId,
136 remote: &net::SocketAddr,
137 ) -> Option<(&LocalTime, &Scheduled<M, D>)> {
138 self.messages
139 .iter()
140 .rev()
141 .find(|(_, v)| &v.node == node && &v.remote == remote)
142 }
143}
144
/// Simulation options.
///
/// The default value has an empty latency range (`0..0`) and a failure rate
/// of `0.0`: no random latencies are assigned and no network partitions are
/// introduced.
#[derive(Debug, Clone, Default)]
pub struct Options {
    /// Range from which the latency between each pair of nodes is drawn, in
    /// seconds. An empty range disables random latency assignment.
    pub latency: Range<u64>,
    /// Probability, in `0.0..=1.0`, that a pair of nodes is partitioned each
    /// time partitions are re-evaluated.
    pub failure_rate: f64,
}
163
164pub struct Simulation<T>
166where
167 T: StateMachine,
168{
169 inbox: Inbox<<T::Message as ToOwned>::Owned, T::DisconnectReason>,
171 events: BTreeMap<NodeId, VecDeque<T::Event>>,
173 priority: VecDeque<Scheduled<<T::Message as ToOwned>::Owned, T::DisconnectReason>>,
175 latencies: BTreeMap<(NodeId, NodeId), LocalDuration>,
177 partitions: BTreeSet<(NodeId, NodeId)>,
179 connections: BTreeMap<(NodeId, NodeId), u16>,
181 attempts: BTreeSet<(NodeId, NodeId)>,
183 opts: Options,
185 start_time: LocalTime,
187 time: LocalTime,
189 rng: fastrand::Rng,
191}
192
193impl<T> Simulation<T>
194where
195 T: StateMachine + 'static,
196 T::DisconnectReason: Clone + Into<Disconnect<T::DisconnectReason>>,
197
198 <T::Message as ToOwned>::Owned: fmt::Debug + Clone,
199{
200 pub fn new(time: LocalTime, rng: fastrand::Rng, opts: Options) -> Self {
202 Self {
203 inbox: Inbox {
204 messages: BTreeMap::new(),
205 },
206 events: BTreeMap::new(),
207 priority: VecDeque::new(),
208 partitions: BTreeSet::new(),
209 latencies: BTreeMap::new(),
210 connections: BTreeMap::new(),
211 attempts: BTreeSet::new(),
212 opts,
213 start_time: time,
214 time,
215 rng,
216 }
217 }
218
219 pub fn is_done(&self) -> bool {
221 self.inbox.messages.is_empty()
222 }
223
224 #[allow(dead_code)]
226 pub fn elapsed(&self) -> LocalDuration {
227 self.time - self.start_time
228 }
229
230 pub fn is_settled(&self) -> bool {
233 self.inbox
234 .messages
235 .iter()
236 .all(|(_, s)| matches!(s.input, Input::Wake))
237 }
238
239 pub fn events(&mut self, node: &NodeId) -> impl Iterator<Item = T::Event> + '_ {
241 self.events.entry(*node).or_default().drain(..)
242 }
243
244 pub fn latency(&self, from: NodeId, to: NodeId) -> LocalDuration {
246 self.latencies
247 .get(&(from, to))
248 .cloned()
249 .map(|l| {
250 if l <= MIN_LATENCY {
251 l
252 } else {
253 let millis = l.as_millis();
256
257 if self.rng.bool() {
258 LocalDuration::from_millis(millis + self.rng.u128(0..millis))
260 } else {
261 LocalDuration::from_millis(millis - self.rng.u128(0..millis / 2))
263 }
264 }
265 })
266 .unwrap_or_else(|| MIN_LATENCY)
267 }
268
269 pub fn initialize<'a, P: Peer<T>>(self, peers: impl IntoIterator<Item = &'a mut P>) -> Self {
271 for peer in peers.into_iter() {
272 peer.init();
273 }
274 self
275 }
276
277 pub fn run_while<'a, P: Peer<T>>(
279 &mut self,
280 peers: impl IntoIterator<Item = &'a mut P>,
281 pred: impl Fn(&Self) -> bool,
282 ) {
283 let mut nodes: BTreeMap<_, _> = peers.into_iter().map(|p| (p.addr().ip(), p)).collect();
284
285 while self.step_(&mut nodes) {
286 if !pred(self) {
287 break;
288 }
289 }
290 }
291
292 pub fn step<'a, P: Peer<T>>(&mut self, peers: impl IntoIterator<Item = &'a mut P>) -> bool {
296 let mut nodes: BTreeMap<_, _> = peers.into_iter().map(|p| (p.addr().ip(), p)).collect();
297 self.step_(&mut nodes)
298 }
299
300 fn step_<P: Peer<T>>(&mut self, nodes: &mut BTreeMap<NodeId, &mut P>) -> bool {
301 if !self.opts.latency.is_empty() {
302 for (i, from) in nodes.keys().enumerate() {
304 for to in nodes.keys().skip(i + 1) {
305 let range = self.opts.latency.clone();
306 let latency = LocalDuration::from_millis(
307 self.rng
308 .u128(range.start as u128 * 1_000..range.end as u128 * 1_000),
309 );
310
311 self.latencies.entry((*from, *to)).or_insert(latency);
312 self.latencies.entry((*to, *from)).or_insert(latency);
313 }
314 }
315 }
316
317 if self.time.as_secs() % 10 == 0 {
323 for (i, x) in nodes.keys().enumerate() {
324 for y in nodes.keys().skip(i + 1) {
325 if self.is_fallible() {
326 self.partitions.insert((*x, *y));
327 } else {
328 self.partitions.remove(&(*x, *y));
329 }
330 }
331 }
332 }
333
334 for peer in nodes.values_mut() {
336 let ip = peer.addr().ip();
337
338 for o in peer.by_ref() {
339 self.schedule(&ip, o);
340 }
341 }
342 let priority = self.priority.pop_front().map(|s| (self.time, s));
344
345 if let Some((time, next)) = priority.or_else(|| self.inbox.next()) {
346 let elapsed = (time - self.start_time).as_millis();
347 if matches!(next.input, Input::Wake) {
348 trace!(target: "sim", "{:05} {}", elapsed, next);
349 } else {
350 info!(target: "sim", "{:05} {} ({})", elapsed, next, self.inbox.messages.len());
354 }
355 assert!(time >= self.time, "Time only moves forwards!");
356
357 self.time = time;
358 self.inbox.messages.remove(&time);
359
360 let Scheduled { input, node, .. } = next;
361
362 if let Some(ref mut p) = nodes.get_mut(&node) {
363 p.tick(time);
364
365 match input {
366 Input::Connecting { addr } => {
367 if self.attempts.insert((node, addr.ip())) {
368 p.attempted(&addr);
369 }
370 }
371 Input::Connected {
372 addr,
373 local_addr,
374 link,
375 } => {
376 let conn = (node, addr.ip());
377
378 let attempted = link.is_outbound() && self.attempts.remove(&conn);
379 if attempted || link.is_inbound() {
380 if self.connections.insert(conn, local_addr.port()).is_none() {
381 p.connected(addr, &local_addr, link);
382 }
383 }
384 }
385 Input::Disconnected(addr, reason) => {
386 let conn = (node, addr.ip());
387 let attempt = self.attempts.remove(&conn);
388 let connection = self.connections.remove(&conn).is_some();
389
390 assert!(!(attempt && connection));
392
393 if attempt || connection {
394 p.disconnected(&addr, reason);
395 }
396 }
397 Input::Wake => p.timer_expired(),
398 Input::Received(addr, msg) => {
399 p.message_received(&addr, Cow::Owned(msg));
400 }
401 }
402 for o in p.by_ref() {
403 self.schedule(&node, o);
404 }
405 } else {
406 panic!(
407 "Node {} not found when attempting to schedule {:?}",
408 node, input
409 );
410 }
411 }
412 !self.is_done()
413 }
414
415 pub fn schedule(
417 &mut self,
418 node: &NodeId,
419 out: Io<<T::Message as ToOwned>::Owned, T::Event, T::DisconnectReason, net::SocketAddr>,
420 ) {
421 let node = *node;
422
423 match out {
424 Io::Write(receiver, msg) => {
425 let port = if let Some(port) = self.connections.get(&(node, receiver.ip())) {
428 *port
429 } else {
430 return;
431 };
432
433 let sender: net::SocketAddr = (node, port).into();
434 if self.is_partitioned(sender.ip(), receiver.ip()) {
435 info!(
437 target: "sim",
438 "{} -> {} (DROPPED)",
439 sender, receiver,
440 );
441 return;
442 }
443
444 let latency = self.latency(node, receiver.ip());
447 let time = self
448 .inbox
449 .last(&receiver.ip(), &sender)
450 .map(|(k, _)| *k)
451 .unwrap_or_else(|| self.time);
452 let time = time + latency;
453 let elapsed = (time - self.start_time).as_millis();
454
455 info!(
456 target: "sim",
457 "{:05} {} -> {}: ({:?}) ({})",
458 elapsed, sender, receiver, &msg, latency
459 );
460
461 self.inbox.insert(
462 time,
463 Scheduled {
464 remote: sender,
465 node: receiver.ip(),
466 input: Input::Received(sender, msg),
467 },
468 );
469 }
470 Io::Connect(remote) => {
471 assert!(remote.ip() != node, "self-connections are not allowed");
472
473 let local_addr: net::SocketAddr = net::SocketAddr::new(node, self.rng.u16(8192..));
475 let latency = self.latency(node, remote.ip());
476
477 self.inbox.insert(
478 self.time + MIN_LATENCY,
479 Scheduled {
480 node,
481 remote,
482 input: Input::Connecting { addr: remote },
483 },
484 );
485
486 if self.is_partitioned(node, remote.ip()) {
488 log::info!(target: "sim", "{} -/-> {} (partitioned)", node, remote.ip());
489
490 if self.rng.bool() {
492 self.inbox.insert(
493 self.time + MIN_LATENCY,
494 Scheduled {
495 node,
496 remote,
497 input: Input::Disconnected(
498 remote,
499 Disconnect::ConnectionError(
500 io::Error::from(io::ErrorKind::UnexpectedEof).into(),
501 ),
502 ),
503 },
504 );
505 }
506 return;
507 }
508
509 self.inbox.insert(
510 self.time + latency,
512 Scheduled {
513 node: remote.ip(),
514 remote: local_addr,
515 input: Input::Connected {
516 addr: local_addr,
517 local_addr: remote,
518 link: Link::Inbound,
519 },
520 },
521 );
522 self.inbox.insert(
523 self.time + latency,
525 Scheduled {
526 remote,
527 node,
528 input: Input::Connected {
529 addr: remote,
530 local_addr,
531 link: Link::Outbound,
532 },
533 },
534 );
535 }
536 Io::Disconnect(remote, reason) => {
537 self.priority.push_back(Scheduled {
539 remote,
540 node,
541 input: Input::Disconnected(remote, reason.into()),
542 });
543
544 let port = if let Some(port) = self.connections.get(&(node, remote.ip())) {
551 *port
552 } else {
553 debug!(target: "sim", "Ignoring disconnect of {remote} from {node}");
554 return;
555 };
556 let local_addr: net::SocketAddr = (node, port).into();
557 let latency = self.latency(node, remote.ip());
558
559 self.inbox.insert(
561 self.time + latency,
562 Scheduled {
563 node: remote.ip(),
564 remote: local_addr,
565 input: Input::Disconnected(
566 local_addr,
567 Disconnect::ConnectionError(
568 io::Error::from(io::ErrorKind::ConnectionReset).into(),
569 ),
570 ),
571 },
572 );
573 }
574 Io::SetTimer(duration) => {
575 let time = self.time + duration;
576
577 if !matches!(
578 self.inbox.messages.get(&time),
579 Some(Scheduled {
580 input: Input::Wake,
581 ..
582 })
583 ) {
584 self.inbox.insert(
585 time,
586 Scheduled {
587 node,
588 remote: ([0, 0, 0, 0], 0).into(),
590 input: Input::Wake,
591 },
592 );
593 }
594 }
595 Io::Event(event) => {
596 let events = self.events.entry(node).or_insert_with(VecDeque::new);
597 if events.len() >= MAX_EVENTS {
598 warn!(target: "sim", "Dropping event: buffer is full");
599 } else {
600 events.push_back(event);
601 }
602 }
603 }
604 }
605
606 fn is_fallible(&self) -> bool {
608 self.rng.f64() % 1.0 < self.opts.failure_rate
609 }
610
611 fn is_partitioned(&self, a: NodeId, b: NodeId) -> bool {
613 self.partitions.contains(&(a, b)) || self.partitions.contains(&(b, a))
614 }
615}