1#![allow(missing_docs)]
2
3use std::collections::HashMap;
4use std::convert::Infallible;
5use std::convert::TryInto;
6use std::sync::Arc;
7
8use futures::channel::mpsc;
9use futures::future::BoxFuture;
10use futures::future::FutureExt;
11use futures::lock::Mutex;
12use futures::sink::SinkExt;
13use futures::stream::Stream;
14use thiserror::Error;
15
16use crate::append::AppendError;
17use crate::communicator::Acceptance;
18use crate::communicator::AcceptanceFor;
19use crate::communicator::Committed;
20use crate::communicator::Communicator;
21use crate::communicator::Vote;
22use crate::communicator::VoteFor;
23use crate::error::ShutDown;
24use crate::invocation::AbstainOf;
25use crate::invocation::CoordNumOf;
26use crate::invocation::Invocation;
27use crate::invocation::LogEntryOf;
28use crate::invocation::NayOf;
29use crate::invocation::NodeIdOf;
30use crate::invocation::NodeOf;
31use crate::invocation::RoundNumOf;
32use crate::invocation::YeaOf;
33use crate::retry::RetryPolicy;
34use crate::LogEntry;
35use crate::NodeInfo;
36use crate::RequestHandler;
37
38#[derive(
40 Clone, Copy, Debug, Default, Eq, Hash, PartialEq, serde::Deserialize, serde::Serialize,
41)]
42pub struct PrototypingNode(usize);
43
/// Process-wide counter from which `PrototypingNode::new` draws ids; starts at `0`.
static NODE_ID_DISPENSER: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(0);
45
46impl PrototypingNode {
47 pub fn new() -> Self {
48 Self(NODE_ID_DISPENSER.fetch_add(1, std::sync::atomic::Ordering::Relaxed))
49 }
50
51 pub fn with_id(id: usize) -> Self {
52 Self(id)
53 }
54}
55
56impl NodeInfo for PrototypingNode {
57 type Id = usize;
58
59 fn id(&self) -> Self::Id {
60 self.0
61 }
62}
63
64#[derive(Debug)]
65pub struct RetryIndefinitely<I>(u64, crate::util::PhantomSend<I>);
66
67impl<I> RetryIndefinitely<I> {
68 pub fn without_pausing() -> Self {
69 Self(0, crate::util::PhantomSend::new())
70 }
71
72 pub fn pausing_up_to(duration: std::time::Duration) -> Self {
73 Self(duration.as_millis() as u64, crate::util::PhantomSend::new())
74 }
75}
76
77impl<I: Invocation> RetryPolicy for RetryIndefinitely<I> {
78 type Invocation = I;
79 type Error = Infallible;
80 type StaticError = ShutDown;
81 type Future = BoxFuture<'static, Result<(), Self::Error>>;
82
83 fn eval(&mut self, _err: AppendError<Self::Invocation>) -> Self::Future {
84 let limit = self.0;
85
86 async move {
87 if limit > 0 {
88 use rand::Rng;
89
90 let delay = rand::thread_rng().gen_range(0..=limit);
91 let delay = std::time::Duration::from_millis(delay);
92
93 sleep(delay).await;
94 }
95
96 Ok(())
97 }
98 .boxed()
99 }
100}
101
102type RequestHandlers<I> = HashMap<NodeIdOf<I>, RequestHandler<I>>;
103type EventListeners<I> = Vec<mpsc::Sender<DirectCommunicatorEvent<I>>>;
104type PacketLossRates<I> = HashMap<(NodeIdOf<I>, NodeIdOf<I>), f32>;
105type E2eDelays<I> = HashMap<(NodeIdOf<I>, NodeIdOf<I>), rand_distr::Normal<f32>>;
106
107#[derive(Debug)]
108pub struct DirectCommunicators<I: Invocation> {
109 #[allow(clippy::type_complexity)]
110 request_handlers: Arc<Mutex<RequestHandlers<I>>>,
111 default_packet_loss: f32,
112 default_e2e_delay: rand_distr::Normal<f32>,
113 packet_loss: Arc<Mutex<PacketLossRates<I>>>,
114 e2e_delay: Arc<Mutex<E2eDelays<I>>>,
115 event_listeners: Arc<Mutex<EventListeners<I>>>,
116}
117
118impl<I: Invocation> DirectCommunicators<I> {
119 pub fn new() -> Self {
120 Self::with_characteristics(0.0, rand_distr::Normal::new(0.0, 0.0).unwrap())
121 }
122
123 pub fn with_characteristics(packet_loss: f32, e2e_delay: rand_distr::Normal<f32>) -> Self {
124 Self {
125 request_handlers: Arc::new(Mutex::new(HashMap::new())),
126 default_packet_loss: packet_loss,
127 default_e2e_delay: e2e_delay,
128 packet_loss: Arc::new(Mutex::new(HashMap::new())),
129 e2e_delay: Arc::new(Mutex::new(HashMap::new())),
130 event_listeners: Arc::new(Mutex::new(Vec::new())),
131 }
132 }
133
134 pub async fn set_packet_loss(&mut self, from: NodeIdOf<I>, to: NodeIdOf<I>, packet_loss: f32) {
135 let mut link = self.packet_loss.lock().await;
136 link.insert((from, to), packet_loss);
137 }
138
139 pub async fn set_delay(
140 &mut self,
141 from: NodeIdOf<I>,
142 to: NodeIdOf<I>,
143 delay: rand_distr::Normal<f32>,
144 ) {
145 let mut link = self.e2e_delay.lock().await;
146 link.insert((from, to), delay);
147 }
148
149 pub async fn register(&self, node_id: NodeIdOf<I>, handler: RequestHandler<I>) {
150 let mut handlers = self.request_handlers.lock().await;
151 handlers.insert(node_id, handler);
152 }
153
154 pub fn events(&self) -> impl Stream<Item = DirectCommunicatorEvent<I>> {
155 let (send, recv) = mpsc::channel(16);
156
157 futures::executor::block_on(async {
158 let mut listeners = self.event_listeners.lock().await;
159 listeners.push(send);
160 });
161
162 recv
163 }
164
165 pub fn create_communicator_for(&self, node_id: NodeIdOf<I>) -> DirectCommunicator<I> {
166 DirectCommunicator {
167 set: self.clone(),
168 node_id,
169 }
170 }
171}
172
173impl<I: Invocation> Clone for DirectCommunicators<I> {
174 fn clone(&self) -> Self {
175 Self {
176 request_handlers: Arc::clone(&self.request_handlers),
177 default_packet_loss: self.default_packet_loss,
178 default_e2e_delay: self.default_e2e_delay,
179 packet_loss: Arc::clone(&self.packet_loss),
180 e2e_delay: Arc::clone(&self.e2e_delay),
181 event_listeners: Arc::clone(&self.event_listeners),
182 }
183 }
184}
185
186impl<I: Invocation> Default for DirectCommunicators<I> {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
192#[derive(Clone, Debug)]
193pub struct DirectCommunicatorEvent<I: Invocation> {
194 pub sender: NodeIdOf<I>,
195 pub receiver: NodeIdOf<I>,
196 pub e2e_delay: std::time::Duration,
197 pub dropped: bool,
198 pub payload: DirectCommunicatorPayload<I>,
199}
200
201#[derive(Clone, Debug)]
202pub enum DirectCommunicatorPayload<I: Invocation> {
203 Prepare {
204 round_num: RoundNumOf<I>,
205 coord_num: CoordNumOf<I>,
206 },
207 Promise(bool),
208 Propose {
209 round_num: RoundNumOf<I>,
210 coord_num: CoordNumOf<I>,
211 log_entry: Arc<LogEntryOf<I>>,
212 },
213 Accept(bool),
214 Commit {
215 round_num: RoundNumOf<I>,
216 coord_num: CoordNumOf<I>,
217 log_entry: Arc<LogEntryOf<I>>,
218 },
219 CommitById {
220 round_num: RoundNumOf<I>,
221 coord_num: CoordNumOf<I>,
222 },
223 Committed(bool),
224}
225
226#[derive(Debug, Error)]
227pub enum DirectCommunicatorError {
228 #[error("other")]
229 Other,
230
231 #[error("timeout")]
232 Timeout,
233}
234
235#[derive(Debug)]
236pub struct DirectCommunicator<I: Invocation> {
237 set: DirectCommunicators<I>,
238 node_id: NodeIdOf<I>,
239}
240
241impl<I: Invocation> Clone for DirectCommunicator<I> {
242 fn clone(&self) -> Self {
243 Self {
244 set: self.set.clone(),
245 node_id: self.node_id,
246 }
247 }
248}
249
/// Expands to the body of a `Communicator::send_*` method.
///
/// For every receiver it yields `(receiver, future)`, where the boxed future simulates one
/// request/response round trip over the network of the communicator's
/// `DirectCommunicators` set.
///
/// Invocation shape:
///
/// ```text
/// send_fn!(
///     self, receivers [, non_copy_arg]* ;   // non-copy args are cloned per receiver
///     handler_method [, arg]* ;              // RequestHandler method and its arguments
///     request_payload_expr ;                 // payload reported with the request event
///     response_payload_closure ;             // &response -> payload of the response event
/// )
/// ```
macro_rules! send_fn {
    (
        $self:ident, $receivers:ident $(, $non_copy_arg:ident)* ;
        $method:ident $(, $arg:ident)* ;
        $request_payload:expr;
        $response_payload:expr;
    ) => {{
        $receivers
            .iter()
            .map(move |receiver| {
                // Each future owns its own handle. Cloning is cheap because the set
                // shares its state via `Arc`.
                let this = $self.clone();
                let receiver_id = receiver.id();

                // Give every per-receiver future its own copy of the non-`Copy` arguments.
                // The remaining `$arg`s are presumably `Copy`; they are moved into each
                // future as-is.
                $( send_fn!(@ $non_copy_arg); )*

                (
                    receiver,
                    async move {
                        // Link characteristics for both directions. Per-link overrides are
                        // keyed by `(from, to)`; links without one use the set's defaults.
                        let (packet_loss_rate_there, packet_loss_rate_back) = {
                            let per_link = this.set.packet_loss.lock().await;

                            let there = per_link.get(&(this.node_id, receiver_id)).copied();
                            let there = there.unwrap_or(this.set.default_packet_loss);

                            let back = per_link.get(&(receiver_id, this.node_id)).copied();
                            let back = back.unwrap_or(this.set.default_packet_loss);

                            (there, back)
                        };
                        let (e2e_delay_distr_there, e2e_delay_distr_back) = {
                            let per_link = this.set.e2e_delay.lock().await;

                            let there = per_link.get(&(this.node_id, receiver_id)).copied();
                            let there = there.unwrap_or(this.set.default_e2e_delay);

                            let back = per_link.get(&(receiver_id, this.node_id)).copied();
                            let back = back.unwrap_or(this.set.default_e2e_delay);

                            (there, back)
                        };

                        // Request leg: sample its fate up front so listeners can be told
                        // about it.
                        let e2e_delay = delay(&e2e_delay_distr_there);
                        let dropped = roll_for_failure(packet_loss_rate_there);

                        // NOTE(review): the listener lock is held across each `send().await`.
                        // A listener whose (bounded) stream is not drained stalls this
                        // future while holding the lock. Other round trips, and
                        // `DirectCommunicators::events`, then wait on it too. Holding the
                        // lock does keep the event order identical for all listeners.
                        {
                            let listeners = this.set.event_listeners.lock().await;
                            for mut l in listeners.iter().cloned() {
                                let _ = l.send(DirectCommunicatorEvent {
                                    sender: this.node_id,
                                    receiver: receiver_id,
                                    e2e_delay,
                                    dropped,
                                    payload: $request_payload,
                                }).await;
                            }
                        }

                        sleep(e2e_delay).await;

                        // A lost request surfaces as a timeout after the simulated delay.
                        if dropped {
                            return Err(DirectCommunicatorError::Timeout);
                        }

                        // The handler lock is released at the end of this block; only the
                        // future returned by the handler is awaited.
                        // An unregistered receiver yields `Other` without a response event.
                        let result = {
                            let handlers = this.set.request_handlers.lock().await;
                            let handler = match handlers.get(&receiver_id) {
                                Some(handler) => handler,
                                None => return Err(DirectCommunicatorError::Other),
                            };

                            handler.$method($($arg),*)
                        }
                        .await;
                        // Convert into this method's response type; a conversion failure
                        // becomes `Other`.
                        let response = result
                            .try_into()
                            .map_err(|_| DirectCommunicatorError::Other);

                        // Response leg: same simulation as the request leg, in the
                        // opposite direction.
                        let e2e_delay = delay(&e2e_delay_distr_back);
                        let dropped = roll_for_failure(packet_loss_rate_back);

                        {
                            let listeners = this.set.event_listeners.lock().await;
                            for mut l in listeners.iter().cloned() {
                                let _ = l.send(DirectCommunicatorEvent {
                                    sender: receiver_id,
                                    receiver: this.node_id,
                                    e2e_delay,
                                    dropped,
                                    payload: $response_payload(&response),
                                }).await;
                            }
                        }

                        sleep(e2e_delay).await;

                        // A lost response also surfaces as a timeout, even though the
                        // receiver has already processed the request.
                        if dropped {
                            return Err(DirectCommunicatorError::Timeout);
                        }

                        response
                    }
                    .boxed(),
                )
            })
            .collect()
    }};

    // Internal rule: shadow `$non_copy_arg` with a clone owned by the current receiver's
    // future.
    (@ $non_copy_arg:ident) => {
        let $non_copy_arg = $non_copy_arg.clone();
    }
}
361
/// Sends every message through the simulated network of the communicator's set.
///
/// Each `send_*` method returns one `'static` boxed future per receiver; the futures are
/// built by `send_fn!`.
impl<I: Invocation + 'static> Communicator for DirectCommunicator<I> {
    type Node = NodeOf<I>;

    type RoundNum = RoundNumOf<I>;
    type CoordNum = CoordNumOf<I>;

    type LogEntry = LogEntryOf<I>;

    type Error = DirectCommunicatorError;

    type SendPrepare = BoxFuture<'static, Result<VoteFor<Self>, Self::Error>>;
    type Abstain = AbstainOf<I>;

    type SendProposal = BoxFuture<'static, Result<AcceptanceFor<Self>, Self::Error>>;
    type Yea = YeaOf<I>;
    type Nay = NayOf<I>;

    type SendCommit = BoxFuture<'static, Result<Committed, Self::Error>>;
    type SendCommitById = BoxFuture<'static, Result<Committed, Self::Error>>;

    /// Sends a prepare request to each receiver via `RequestHandler::handle_prepare`.
    ///
    /// The response event reports `Promise(true)` iff the vote was given.
    fn send_prepare<'a>(
        &mut self,
        receivers: &'a [Self::Node],
        round_num: Self::RoundNum,
        coord_num: Self::CoordNum,
    ) -> Vec<(&'a Self::Node, Self::SendPrepare)> {
        // The closure pattern must stay `&Ok(..)`: the closure's argument type is only
        // known from the pattern when the macro-generated call is type-checked.
        send_fn!(
            self, receivers;
            handle_prepare, round_num, coord_num;
            DirectCommunicatorPayload::Prepare { round_num, coord_num };
            |r| DirectCommunicatorPayload::Promise(matches!(r, &Ok(Vote::Given(_))));
        )
    }

    /// Sends `log_entry` as a proposal to each receiver via
    /// `RequestHandler::handle_proposal`.
    ///
    /// The `Arc` is cloned per receiver and per event. The response event reports
    /// `Accept(true)` iff acceptance was given.
    fn send_proposal<'a>(
        &mut self,
        receivers: &'a [Self::Node],
        round_num: Self::RoundNum,
        coord_num: Self::CoordNum,
        log_entry: Arc<Self::LogEntry>,
    ) -> Vec<(&'a Self::Node, Self::SendProposal)> {
        send_fn!(
            self, receivers, log_entry;
            handle_proposal, round_num, coord_num, log_entry;
            DirectCommunicatorPayload::Propose { round_num, coord_num, log_entry: log_entry.clone() };
            |r| DirectCommunicatorPayload::Accept(matches!(r, &Ok(Acceptance::Given(_))));
        )
    }

    /// Sends a commit carrying the full `log_entry` to each receiver via
    /// `RequestHandler::handle_commit`.
    ///
    /// The response event reports `Committed(true)` iff the result is `Ok`.
    fn send_commit<'a>(
        &mut self,
        receivers: &'a [Self::Node],
        round_num: Self::RoundNum,
        coord_num: Self::CoordNum,
        log_entry: Arc<Self::LogEntry>,
    ) -> Vec<(&'a Self::Node, Self::SendCommit)> {
        send_fn!(
            self, receivers, log_entry;
            handle_commit, round_num, coord_num, log_entry;
            DirectCommunicatorPayload::Commit { round_num, coord_num, log_entry: log_entry.clone() };
            |r| DirectCommunicatorPayload::Committed(matches!(r, &Ok(_)));
        )
    }

    /// Sends a commit that names the log entry only by `log_entry_id` to each receiver
    /// via `RequestHandler::handle_commit_by_id`.
    ///
    /// The id is passed to the handler but is not recorded in the request event.
    fn send_commit_by_id<'a>(
        &mut self,
        receivers: &'a [Self::Node],
        round_num: Self::RoundNum,
        coord_num: Self::CoordNum,
        log_entry_id: <Self::LogEntry as LogEntry>::Id,
    ) -> Vec<(&'a Self::Node, Self::SendCommitById)> {
        send_fn!(
            self, receivers;
            handle_commit_by_id, round_num, coord_num, log_entry_id;
            DirectCommunicatorPayload::CommitById { round_num, coord_num };
            |r| DirectCommunicatorPayload::Committed(matches!(r, &Ok(_)));
        )
    }
}
441
442fn roll_for_failure(rate: f32) -> bool {
443 use rand::Rng;
444
445 rand::thread_rng().gen::<f32>() < rate
446}
447
448async fn sleep(duration: std::time::Duration) {
449 if duration > std::time::Duration::ZERO {
450 futures_timer::Delay::new(duration).await;
451 }
452}
453
454fn delay(distr: &rand_distr::Normal<f32>) -> std::time::Duration {
455 use rand::distributions::Distribution;
456
457 let delay_ms = distr.sample(&mut rand::thread_rng());
458 let delay_ms = delay_ms as u64;
459
460 std::time::Duration::from_millis(delay_ms)
461}