use std::{
collections::HashMap,
num::NonZeroU64,
time::{Duration, Instant},
};
use minbft::{
output::TimeoutRequest,
timeout::{StopClass, TimeoutType},
Config, Error, MinBft, Output, PeerMessage, RequestPayload,
};
use rand::rngs::ThreadRng;
use serde::{Deserialize, Serialize};
use shared_ids::{ClientId, ReplicaId, RequestId};
use usig::{noop::UsigNoOp, Usig};
use rand::prelude::SliceRandom;
use anyhow::{anyhow, Result};
/// Minimal request payload for tests: field `.0` is the raw request id and
/// field `.1` marks whether the request should pass verification.
#[derive(Debug, Serialize, Deserialize, Clone, Copy, PartialEq, Eq)]
pub(super) struct DummyPayload(pub(super) u64, pub(super) bool);
impl RequestPayload for DummyPayload {
fn id(&self) -> RequestId {
RequestId::from_u64(self.0)
}
fn verify(&self, _id: ClientId) -> Result<()> {
self.1
.then_some(())
.ok_or_else(|| anyhow!("invalid request"))
}
}
/// A freshly created replica: the MinBFT instance paired with the output it
/// produced at creation, plus a timeout handler dedicated to that replica.
type MinBftSetup = (
    (
        MinBft<DummyPayload, UsigNoOp>,
        Output<DummyPayload, UsigNoOp>,
    ),
    TimeoutHandler,
);
/// A whole test system: every replica's MinBFT instance and its timeout
/// handler, both keyed by the replica's id.
type SetupSet = (
    HashMap<ReplicaId, MinBft<DummyPayload, UsigNoOp>>,
    HashMap<ReplicaId, TimeoutHandler>,
);
/// Creates a single MinBFT replica (with its creation-time [`Output`]) and a
/// fresh [`TimeoutHandler`] for it.
///
/// `n` is the total number of replicas, `t` the fault tolerance, `id` this
/// replica's id, and `checkpoint_period` the number of accepted batches
/// between checkpoints.
///
/// # Panics
///
/// Panics if `checkpoint_period` or `n` is zero, or if the MinBFT instance
/// cannot be created from the resulting configuration.
pub(crate) fn minimal_setup(n: u64, t: u64, id: ReplicaId, checkpoint_period: u64) -> MinBftSetup {
    let checkpoint_period =
        NonZeroU64::new(checkpoint_period).expect("checkpoint_period must be non-zero");
    (
        MinBft::new(
            UsigNoOp::default(),
            Config {
                n: n.try_into().expect("n must be non-zero"),
                t,
                id,
                // Batch at most one request so tests observe every request
                // individually.
                max_batch_size: Some(1.try_into().expect("> 0")),
                batch_timeout: Duration::from_secs(1),
                initial_timeout_duration: Duration::from_secs(1),
                checkpoint_period,
            },
        )
        .expect("creating MinBFT instance from valid config should succeed"),
        TimeoutHandler::default(),
    )
}
/// One tracked timeout: its type, the instant at which it is due, and the
/// stop class it was started with (used to match later stop requests).
#[derive(Debug, Clone, Copy)]
struct TimeoutEntry {
    timeout_type: TimeoutType,
    timeout_deadline: Instant,
    stop_class: StopClass,
}
/// Tracks at most one timeout per [`TimeoutType`]; the `bool` marks whether
/// the entry has already been handed out by `retrieve_timeouts_ordered`.
#[derive(Debug, Clone, Default)]
pub(crate) struct TimeoutHandler(HashMap<TimeoutType, (TimeoutEntry, bool)>);
impl TimeoutHandler {
    /// Processes one timeout request.
    ///
    /// A `Start` request registers the timeout unless one of the same
    /// [`TimeoutType`] is already tracked (the first start wins). A `Stop`
    /// request removes the tracked timeout only when the request's
    /// [`StopClass`] matches the class the timeout was started with.
    pub(crate) fn handle_timeout_request(&mut self, timeout_request: TimeoutRequest) {
        if let TimeoutRequest::Start(timeout) = timeout_request {
            // Entry API: a single lookup instead of contains_key + insert,
            // and the deadline is only computed when actually inserting.
            self.0.entry(timeout.timeout_type).or_insert_with(|| {
                (
                    TimeoutEntry {
                        timeout_type: timeout.timeout_type,
                        timeout_deadline: Instant::now() + timeout.duration,
                        stop_class: timeout.stop_class,
                    },
                    false,
                )
            });
        } else if let TimeoutRequest::Stop(timeout) = timeout_request {
            // A stale stop request with a different stop class is ignored.
            let class_matches = self
                .0
                .get(&timeout.timeout_type)
                .map_or(false, |(current, _)| current.stop_class == timeout.stop_class);
            if class_matches {
                self.0.remove(&timeout.timeout_type);
            }
        }
    }
    /// Processes a batch of timeout requests in order.
    pub(crate) fn handle_timeout_requests(&mut self, timeout_requests: Vec<TimeoutRequest>) {
        for timeout_request in timeout_requests {
            self.handle_timeout_request(timeout_request);
        }
    }
    /// Returns all not-yet-retrieved timeout types ordered by deadline
    /// (earliest first) and marks them as retrieved so later calls do not
    /// return them again.
    pub(crate) fn retrieve_timeouts_ordered(&mut self) -> Vec<TimeoutType> {
        let mut pending: Vec<TimeoutEntry> = self
            .0
            .values()
            .filter(|(_, retrieved)| !retrieved)
            .map(|(entry, _)| *entry)
            .collect();
        // Unstable sort is fine: HashMap iteration order is arbitrary anyway.
        pending.sort_unstable_by_key(|entry| entry.timeout_deadline);
        let ordered: Vec<TimeoutType> = pending.iter().map(|entry| entry.timeout_type).collect();
        // Flip the flag in place instead of cloning and re-inserting entries.
        for timeout_type in &ordered {
            if let Some((_, retrieved)) = self.0.get_mut(timeout_type) {
                *retrieved = true;
            }
        }
        ordered
    }
}
/// Creates `n` replicas with their timeout handlers and performs the initial
/// handshake by delivering every replica's startup broadcasts to all of its
/// peers. Asserts that no replica produces unexpected output and that every
/// replica ends up ready for client requests.
pub(crate) fn setup_set(n: u64, t: u64, checkpoint_period: u64) -> SetupSet {
    let mut minbfts = HashMap::new();
    let mut timeout_handlers = HashMap::new();
    let mut all_broadcasts = Vec::new();
    // Number of replicas that have reported readiness for client requests.
    let mut hello_done_count = 0;
    for i in 0..n {
        let replica = ReplicaId::from_u64(i);
        let (
            (
                minbft,
                Output {
                    broadcasts,
                    responses,
                    timeout_requests,
                    errors,
                    ready_for_client_requests,
                    primary: _,
                    view_info: _,
                    round: _,
                },
            ),
            timeout_handler,
        ) = minimal_setup(n, t, replica, checkpoint_period);
        // A freshly created replica must not respond, error, or request
        // timeouts before any message exchange.
        assert_eq!(responses.len(), 0);
        assert_eq!(errors.len(), 0);
        assert_eq!(timeout_requests.len(), 0);
        if ready_for_client_requests {
            hello_done_count += 1;
        }
        // Only a single-replica system may be ready before the handshake.
        assert!(!ready_for_client_requests || n == 1);
        all_broadcasts.push((replica, broadcasts));
        minbfts.insert(replica, minbft);
        timeout_handlers.insert(replica, timeout_handler);
    }
    // Deliver every startup broadcast to all peers (the sender excluded).
    for (id, broadcasts) in all_broadcasts.into_iter() {
        for broadcast in Vec::from(broadcasts).into_iter() {
            for (_, minbft) in minbfts.iter_mut().filter(|(i, _)| **i != id) {
                let Output {
                    broadcasts,
                    responses,
                    timeout_requests,
                    errors,
                    ready_for_client_requests,
                    primary: _,
                    view_info: _,
                    round: _,
                } = minbft.handle_peer_message(id, broadcast.clone());
                // Handshake messages must not trigger any further output.
                assert_eq!(broadcasts.len(), 0);
                assert_eq!(responses.len(), 0);
                assert_eq!(errors.len(), 0);
                assert_eq!(timeout_requests.len(), 0);
                if ready_for_client_requests {
                    hello_done_count += 1;
                }
            }
        }
    }
    // Every replica must have become ready exactly once overall.
    assert_eq!(hello_done_count, n);
    (minbfts, timeout_handlers)
}
/// The concrete peer-message type exchanged between test replicas.
type PeerMessageTest =
    PeerMessage<<UsigNoOp as Usig>::Attestation, DummyPayload, <UsigNoOp as Usig>::Signature>;
/// Per-replica output accumulated while driving the replicas through client
/// requests, broadcasts, and timeouts.
#[derive(Default)]
pub(crate) struct CollectedOutput {
    // Responses to clients, keyed by the replica that produced them.
    pub(crate) responses: HashMap<ReplicaId, Vec<(ClientId, DummyPayload)>>,
    // Errors reported by each replica.
    pub(crate) errors: HashMap<ReplicaId, Vec<Error>>,
    // Timeout requests emitted by each replica.
    pub(crate) timeout_requests: HashMap<ReplicaId, Vec<TimeoutRequest>>,
}
impl CollectedOutput {
    /// Feeds each replica's collected [`TimeoutRequest`]s into its
    /// [`TimeoutHandler`] and returns, per replica, the timeout types that
    /// should now be handled, ordered by deadline. Replicas are visited in
    /// random order.
    pub(crate) fn timeouts_to_handle(
        &self,
        timeout_handlers: &mut HashMap<ReplicaId, TimeoutHandler>,
        rng: &mut ThreadRng,
    ) -> HashMap<ReplicaId, Vec<TimeoutType>> {
        let mut shuffled_ids: Vec<ReplicaId> = self.timeout_requests.keys().cloned().collect();
        shuffled_ids.shuffle(rng);
        let mut due_timeouts = HashMap::new();
        for replica_id in shuffled_ids {
            let requests = self.timeout_requests.get(&replica_id).unwrap();
            let handler = timeout_handlers.get_mut(&replica_id).unwrap();
            handler.handle_timeout_requests(requests.to_vec());
            due_timeouts.insert(replica_id, handler.retrieve_timeouts_ordered());
        }
        due_timeouts
    }
}
/// Delivers every broadcast to all replicas except its sender (peers visited
/// in random order), collecting responses, errors, and timeout requests into
/// `collected_output`. Broadcasts produced while handling a message are
/// delivered recursively until no replica broadcasts anything further.
pub(crate) fn handle_broadcasts(
    minbfts: &mut HashMap<ReplicaId, MinBft<DummyPayload, UsigNoOp>>,
    broadcasts_with_origin: Vec<(ReplicaId, Box<[PeerMessageTest]>)>,
    collected_output: &mut CollectedOutput,
    rng: &mut ThreadRng,
) {
    // Broadcasts generated by replicas while handling this round of messages.
    let mut all_broadcasts = Vec::new();
    for (from, messages_to_broadcast) in broadcasts_with_origin {
        for message_to_broadcast in Vec::from(messages_to_broadcast).into_iter() {
            // Deliver to every replica except the sender, in random order.
            let mut replica_ids: Vec<ReplicaId> =
                minbfts.keys().filter(|id| **id != from).cloned().collect();
            replica_ids.shuffle(rng);
            for rep_id in &replica_ids {
                let minbft = minbfts.get_mut(rep_id).unwrap();
                let Output {
                    broadcasts,
                    responses,
                    timeout_requests: timeouts,
                    errors,
                    ready_for_client_requests,
                    primary: _,
                    view_info: _,
                    round: _,
                } = minbft.handle_peer_message(from, message_to_broadcast.clone());
                // The initial handshake must already be complete here.
                assert!(ready_for_client_requests);
                collected_output
                    .responses
                    .entry(*rep_id)
                    .or_default()
                    .append(&mut Vec::from(responses));
                collected_output
                    .errors
                    .entry(*rep_id)
                    .or_default()
                    .append(&mut Vec::from(errors));
                collected_output
                    .timeout_requests
                    .entry(*rep_id)
                    .or_default()
                    .append(&mut Vec::from(timeouts));
                if !broadcasts.is_empty() {
                    all_broadcasts.push((*rep_id, broadcasts));
                }
            }
        }
    }
    // Recurse until the system quiesces (no further broadcasts).
    if !all_broadcasts.is_empty() {
        handle_broadcasts(minbfts, all_broadcasts, collected_output, rng);
    }
}
/// Sends one client request to every replica (visited in random order),
/// collects all resulting output, and then delivers the broadcasts the
/// request triggered until the system quiesces.
pub(crate) fn try_client_request(
    minbfts: &mut HashMap<ReplicaId, MinBft<DummyPayload, UsigNoOp>>,
    client_id: ClientId,
    payload: DummyPayload,
    rng: &mut ThreadRng,
) -> CollectedOutput {
    let mut output = CollectedOutput::default();
    let mut pending_broadcasts = Vec::new();
    let mut shuffled_ids: Vec<ReplicaId> = minbfts.keys().cloned().collect();
    shuffled_ids.shuffle(rng);
    for replica_id in shuffled_ids {
        let minbft = minbfts.get_mut(&replica_id).unwrap();
        let Output {
            broadcasts,
            responses,
            timeout_requests,
            errors,
            ready_for_client_requests,
            primary: _,
            view_info: _,
            round: _,
        } = minbft.handle_client_message(client_id, payload);
        // The initial handshake must already be complete here.
        assert!(ready_for_client_requests);
        output
            .responses
            .entry(replica_id)
            .or_default()
            .extend(Vec::from(responses));
        output
            .errors
            .entry(replica_id)
            .or_default()
            .extend(Vec::from(errors));
        output
            .timeout_requests
            .entry(replica_id)
            .or_default()
            .extend(Vec::from(timeout_requests));
        if !broadcasts.is_empty() {
            pending_broadcasts.push((replica_id, broadcasts));
        }
    }
    handle_broadcasts(minbfts, pending_broadcasts, &mut output, rng);
    output
}
/// Triggers the given timeouts on each replica (replicas visited in random
/// order), collects all resulting output, and delivers any broadcasts the
/// timeouts caused until the system quiesces.
pub(crate) fn force_timeout(
    minbfts: &mut HashMap<ReplicaId, MinBft<DummyPayload, UsigNoOp>>,
    timeouts: &HashMap<ReplicaId, Vec<TimeoutType>>,
    rng: &mut ThreadRng,
) -> CollectedOutput {
    let mut output = CollectedOutput::default();
    let mut pending_broadcasts = Vec::new();
    let mut shuffled_ids: Vec<ReplicaId> = minbfts.keys().cloned().collect();
    shuffled_ids.shuffle(rng);
    for replica_id in shuffled_ids {
        let minbft = minbfts.get_mut(&replica_id).unwrap();
        // Replicas without pending timeouts are simply skipped.
        if let Some(types_to_handle) = timeouts.get(&replica_id) {
            for timeout_type in types_to_handle {
                let Output {
                    broadcasts,
                    responses,
                    timeout_requests,
                    errors,
                    ready_for_client_requests,
                    primary: _,
                    view_info: _,
                    round: _,
                } = minbft.handle_timeout(timeout_type.to_owned());
                // The initial handshake must already be complete here.
                assert!(ready_for_client_requests);
                output
                    .responses
                    .entry(replica_id)
                    .or_default()
                    .extend(Vec::from(responses));
                output
                    .errors
                    .entry(replica_id)
                    .or_default()
                    .extend(Vec::from(errors));
                output
                    .timeout_requests
                    .entry(replica_id)
                    .or_default()
                    .extend(Vec::from(timeout_requests));
                if !broadcasts.is_empty() {
                    pending_broadcasts.push((replica_id, broadcasts));
                }
            }
        }
    }
    handle_broadcasts(minbfts, pending_broadcasts, &mut output, rng);
    output
}
/// Removes random replicas from `minbfts` until only `amount_to_keep`
/// remain. If `explicitly_to_keep` is given (and `amount_to_keep > 0`), that
/// replica is guaranteed to be among the survivors.
///
/// # Panics
///
/// Panics if `amount_to_keep` exceeds the number of replicas, or if a
/// replica that must be kept is not present in `minbfts`.
pub(crate) fn remove_random_replicas_from_hashmap(
    minbfts: &mut HashMap<ReplicaId, MinBft<DummyPayload, UsigNoOp>>,
    amount_to_keep: usize,
    explicitly_to_keep: Option<ReplicaId>,
    rng: &mut ThreadRng,
) {
    assert!(minbfts.len() >= amount_to_keep);
    let mut survivors: Vec<ReplicaId> = minbfts.keys().cloned().collect();
    survivors.shuffle(rng);
    survivors.truncate(amount_to_keep);
    if let Some(explicit) = explicitly_to_keep {
        // The original `assert!(0 <= amount_to_keep.try_into().unwrap())` was
        // trivially true (and could itself panic on conversion); check the
        // invariant that actually matters instead.
        assert!(
            amount_to_keep == 0 || minbfts.contains_key(&explicit),
            "replica to keep explicitly must be present in the map"
        );
        if amount_to_keep != 0 && !survivors.contains(&explicit) {
            // Swap out an arbitrary survivor for the mandatory one so the
            // total count stays at `amount_to_keep`.
            survivors.pop();
            survivors.push(explicit);
        }
    }
    minbfts.retain(|id, _| survivors.contains(id));
}