use crate::cx::Cx;
use crate::error::{Error, ErrorKind, Result};
use crate::time::timeout;
use crate::types::{Outcome, Time};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use super::types::{
ConsensusBatch, ConsensusRequest, ConsensusResponse, MessageCertificate, MessageDigest,
PhaseKind, ReplicaId, SequenceNumber, ViewNumber,
};
/// Static parameters of a PBFT replica group: membership size, tolerated
/// faults, per-phase timeouts and batching limits.
#[derive(Debug, Clone)]
pub struct PbftConfig {
/// Total number of replicas in the group.
pub replica_count: usize,
/// Number of faulty replicas to tolerate; a valid configuration requires
/// `replica_count > 3 * fault_tolerance`.
pub fault_tolerance: usize,
/// Upper bound on how long broadcasting a `PrePrepare` may take.
pub preprepare_timeout: Duration,
/// Upper bound on how long broadcasting a `Prepare` may take.
pub prepare_timeout: Duration,
/// Upper bound on how long broadcasting a `Commit` may take.
pub commit_timeout: Duration,
/// Timeout for view changes. NOTE(review): not read anywhere in the visible
/// code — presumably reserved for the view-change protocol; confirm.
pub view_change_timeout: Duration,
/// Maximum number of pending requests the primary packs into one batch.
pub max_batch_size: usize,
/// Batching delay. NOTE(review): not read anywhere in the visible code —
/// presumably the maximum wait before flushing a partial batch; confirm.
pub batch_timeout: Duration,
}
impl PbftConfig {
pub fn new(replica_count: usize, fault_tolerance: usize) -> Result<Self> {
if replica_count < 3 * fault_tolerance + 1 {
return Err(Error::new(ErrorKind::InvalidInput));
}
Ok(Self {
replica_count,
fault_tolerance,
preprepare_timeout: Duration::from_secs(5),
prepare_timeout: Duration::from_secs(5),
commit_timeout: Duration::from_secs(5),
view_change_timeout: Duration::from_secs(10),
max_batch_size: 100,
batch_timeout: Duration::from_millis(10),
})
}
pub fn is_valid(&self) -> bool {
self.replica_count > 3 * self.fault_tolerance
}
pub fn quorum_size(&self) -> usize {
2 * self.fault_tolerance + 1
}
}
/// Messages exchanged between PBFT replicas (and from clients, via
/// `Request`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PbftMessage {
/// A client request; handled by queueing it for batching.
Request(ConsensusRequest),
/// Proposal of `batch` at (`view`, `sequence`).
PrePrepare {
view: ViewNumber,
sequence: SequenceNumber,
/// Digest of `batch`; receivers recompute it and reject a mismatch.
digest: MessageDigest,
batch: ConsensusBatch,
},
/// Vote from `replica_id` for the proposal identified by
/// (`view`, `sequence`, `digest`).
Prepare {
view: ViewNumber,
sequence: SequenceNumber,
digest: MessageDigest,
replica_id: ReplicaId,
},
/// Commit vote from `replica_id` for the proposal identified by
/// (`view`, `sequence`, `digest`).
Commit {
view: ViewNumber,
sequence: SequenceNumber,
digest: MessageDigest,
replica_id: ReplicaId,
},
/// Request from `replica_id` to move to `new_view`.
/// NOTE(review): the handler for this variant is currently a no-op, so the
/// meaning of `certificates` cannot be confirmed from this file.
ViewChange {
new_view: ViewNumber,
replica_id: ReplicaId,
certificates: Vec<MessageCertificate>,
},
/// Announcement of `view`, carrying nested messages.
/// NOTE(review): the handler for this variant is currently a no-op;
/// presumably `view_change_msgs` holds `ViewChange` messages and
/// `preprepare_msgs` holds `PrePrepare` messages — confirm.
NewView {
view: ViewNumber,
view_change_msgs: Vec<PbftMessage>,
preprepare_msgs: Vec<PbftMessage>,
},
}
impl PbftMessage {
    /// Computes the digest of this message by delegating to
    /// `MessageDigest::of`; any error it reports is returned unchanged.
    pub fn digest(&self) -> Result<MessageDigest> {
        MessageDigest::of(self)
    }

    /// Maps the message to the protocol phase it belongs to.
    ///
    /// A client `Request` is reported as `PhaseKind::PrePrepare`, the same
    /// phase as a `PrePrepare` message.
    pub fn phase(&self) -> PhaseKind {
        match self {
            Self::Request(_) | Self::PrePrepare { .. } => PhaseKind::PrePrepare,
            Self::Prepare { .. } => PhaseKind::Prepare,
            Self::Commit { .. } => PhaseKind::Commit,
            Self::ViewChange { .. } => PhaseKind::ViewChange,
            Self::NewView { .. } => PhaseKind::NewView,
        }
    }
}
/// Mutable protocol state of a single replica.
#[derive(Debug, Clone)]
pub struct PbftState {
/// View this replica is currently in; incoming `PrePrepare`s for any other
/// view are rejected.
pub view: ViewNumber,
/// Sequence counter used by the primary when assigning sequence numbers to
/// new batches.
pub sequence: SequenceNumber,
/// Per-sequence protocol log (proposal plus collected votes).
pub log: HashMap<SequenceNumber, LogEntry>,
/// Client requests that have not been packed into a batch yet (FIFO).
pub pending_requests: VecDeque<ConsensusRequest>,
/// Sequence number of the most recently executed batch.
pub last_executed: SequenceNumber,
/// View-change bookkeeping. NOTE(review): only ever set to `None` in the
/// visible code.
pub view_change_state: Option<ViewChangeState>,
}
/// Everything a replica has recorded about one sequence number.
#[derive(Debug, Clone)]
pub struct LogEntry {
/// The proposed batch of requests.
pub batch: ConsensusBatch,
/// Digest of `batch`; votes must carry the same digest to be counted.
pub digest: MessageDigest,
/// View in which the proposal was made; votes must carry the same view to
/// be counted.
pub view: ViewNumber,
/// Set to `true` whenever an entry is created in the visible code.
pub preprepared: bool,
/// `Prepare` votes received, keyed by the voting replica (one per replica).
pub prepare_msgs: HashMap<ReplicaId, PbftMessage>,
/// `Commit` votes received, keyed by the voting replica (one per replica).
pub commit_msgs: HashMap<ReplicaId, PbftMessage>,
/// Execution result; `None` until the batch has been executed.
pub result: Option<Outcome<Vec<u8>, String>>,
}
/// Bookkeeping for an in-progress view change.
///
/// NOTE(review): this type is never constructed in the visible code, so the
/// field descriptions below are inferred from names and types — confirm
/// against the view-change implementation once it exists.
#[derive(Debug, Clone)]
pub struct ViewChangeState {
/// Presumably the view the replica is trying to move to.
pub target_view: ViewNumber,
/// Presumably the `ViewChange` messages collected so far, keyed by sender.
pub view_change_msgs: HashMap<ReplicaId, PbftMessage>,
/// Presumably whether this replica already sent its own `ViewChange`.
pub sent_view_change: bool,
/// Presumably the time at which the view change was started.
pub started_at: Time,
}
/// Network abstraction used by a replica to exchange `PbftMessage`s.
///
/// Implementations must be `Send + Sync`, and every returned future must be
/// `Send`.
pub trait PbftTransport: Send + Sync {
/// Sends `message` to the single replica identified by `replica_id`.
fn send_to_replica(
&self,
replica_id: &ReplicaId,
message: PbftMessage,
) -> impl std::future::Future<Output = Result<()>> + Send;
/// Sends `message` to the replica group.
/// NOTE(review): whether the sender also receives its own broadcast is not
/// specified here and is up to the implementation — confirm.
fn broadcast(
&self,
message: PbftMessage,
) -> impl std::future::Future<Output = Result<()>> + Send;
/// Resolves with the next incoming message, or with an error.
fn receive(&self) -> impl std::future::Future<Output = Result<PbftMessage>> + Send;
}
/// A single PBFT replica: its identity, configuration, protocol state and the
/// transport it uses to talk to its peers.
pub struct PbftNode<T: PbftTransport> {
// Identity of this replica; its string form is parsed as an index when
// deciding whether this replica is the primary.
replica_id: ReplicaId,
// Group parameters (size, fault tolerance, timeouts, batching).
config: PbftConfig,
// Protocol state, guarded by a standard (blocking) mutex.
state: Arc<Mutex<PbftState>>,
// Channel to the other replicas.
transport: T,
}
impl<T: PbftTransport> PbftNode<T> {
pub fn new(replica_id: ReplicaId, config: PbftConfig, transport: T) -> Result<Self> {
if !config.is_valid() {
return Err(Error::new(ErrorKind::InvalidInput));
}
let state = PbftState {
view: ViewNumber::new(0),
sequence: SequenceNumber::new(0),
log: HashMap::new(),
pending_requests: VecDeque::new(),
last_executed: SequenceNumber::new(0),
view_change_state: None,
};
Ok(Self {
replica_id,
config,
state: Arc::new(Mutex::new(state)),
transport,
})
}
/// Reports whether this replica is the primary of the current view.
///
/// The primary index is obtained from the current view and the replica
/// count; this replica's id is parsed as a `usize` index. An id that is not
/// a valid number is mapped to `usize::MAX` for the comparison.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
pub fn is_primary(&self) -> bool {
    let primary_idx = {
        let guard = self.state.lock().unwrap();
        guard.view.primary(self.config.replica_count)
    };
    let own_idx = self
        .replica_id
        .as_str()
        .parse::<usize>()
        .unwrap_or(usize::MAX);
    own_idx == primary_idx
}
/// Queues `request` and, when this replica is the primary, immediately tries
/// to turn the queue into a batch proposal.
///
/// On a non-primary replica the request is only queued.
///
/// # Errors
///
/// Propagates any error from creating and broadcasting the batch.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
pub async fn submit_request(&self, cx: &Cx, request: ConsensusRequest) -> Result<()> {
    // The guard is a temporary, so the lock is released at the end of this
    // statement — before `is_primary` takes it again.
    self.state.lock().unwrap().pending_requests.push_back(request);

    if !self.is_primary() {
        return Ok(());
    }
    self.try_create_batch(cx).await
}
/// Packs up to `max_batch_size` pending requests into a batch, assigns it the
/// next sequence number and broadcasts the `PrePrepare`.
///
/// Does nothing (and consumes no sequence number) when there is nothing to
/// batch, i.e. the queue is empty or `max_batch_size` is 0.
///
/// # Errors
///
/// Propagates any error from `send_preprepare`.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
async fn try_create_batch(&self, cx: &Cx) -> Result<()> {
    let (batch, sequence, view) = {
        let mut state = self.state.lock().unwrap();

        let take = self
            .config
            .max_batch_size
            .min(state.pending_requests.len());
        if take == 0 {
            return Ok(());
        }
        let requests: Vec<_> = state.pending_requests.drain(..take).collect();
        let batch = ConsensusBatch::new(requests);

        // Advance *before* assigning: execution only ever runs the batch at
        // `last_executed.next()`, and both counters start at 0, so a batch
        // numbered with the initial counter value could never be executed.
        state.sequence = state.sequence.next();
        (batch, state.sequence, state.view)
    };
    // The lock is released before awaiting the broadcast.
    self.send_preprepare(cx, view, sequence, batch).await
}
async fn send_preprepare(
&self,
_cx: &Cx,
view: ViewNumber,
sequence: SequenceNumber,
batch: ConsensusBatch,
) -> Result<()> {
let digest = MessageDigest::of(&batch)?;
{
let mut state = self.state.lock().unwrap();
let entry = LogEntry {
batch: batch.clone(),
digest: digest.clone(),
view,
preprepared: true,
prepare_msgs: HashMap::new(),
commit_msgs: HashMap::new(),
result: None,
};
state.log.insert(sequence, entry);
}
let message = PbftMessage::PrePrepare {
view,
sequence,
digest,
batch,
};
timeout(
Time::from_millis(0),
self.config.preprepare_timeout,
self.transport.broadcast(message),
)
.await
.map_err(|_| Error::new(ErrorKind::DeadlineExceeded))?
}
/// Dispatches an incoming message to the handler for its variant and returns
/// that handler's result unchanged.
pub async fn process_message(&self, cx: &Cx, message: PbftMessage) -> Result<()> {
    match message {
        PbftMessage::Request(req) => self.submit_request(cx, req).await,
        PbftMessage::PrePrepare { view, sequence, digest, batch } => {
            self.handle_preprepare(cx, view, sequence, digest, batch).await
        }
        PbftMessage::Prepare { view, sequence, digest, replica_id: sender } => {
            self.handle_prepare(cx, view, sequence, digest, sender).await
        }
        PbftMessage::Commit { view, sequence, digest, replica_id: sender } => {
            self.handle_commit(cx, view, sequence, digest, sender).await
        }
        PbftMessage::ViewChange { new_view, replica_id: sender, certificates } => {
            self.handle_view_change(cx, new_view, sender, certificates).await
        }
        PbftMessage::NewView { view, view_change_msgs, preprepare_msgs } => {
            self.handle_new_view(cx, view, view_change_msgs, preprepare_msgs).await
        }
    }
}
async fn handle_preprepare(
&self,
_cx: &Cx,
view: ViewNumber,
sequence: SequenceNumber,
digest: MessageDigest,
batch: ConsensusBatch,
) -> Result<()> {
{
let state = self.state.lock().unwrap();
if view != state.view {
return Err(Error::new(ErrorKind::InvalidInput));
}
}
let computed_digest = MessageDigest::of(&batch)?;
if digest != computed_digest {
return Err(Error::new(ErrorKind::InvalidInput));
}
{
let mut state = self.state.lock().unwrap();
let entry = LogEntry {
batch,
digest: digest.clone(),
view,
preprepared: true,
prepare_msgs: HashMap::new(),
commit_msgs: HashMap::new(),
result: None,
};
state.log.insert(sequence, entry);
}
let prepare_msg = PbftMessage::Prepare {
view,
sequence,
digest,
replica_id: self.replica_id.clone(),
};
timeout(
Time::from_millis(0),
self.config.prepare_timeout,
self.transport.broadcast(prepare_msg),
)
.await
.map_err(|_| Error::new(ErrorKind::DeadlineExceeded))?
}
async fn handle_prepare(
&self,
_cx: &Cx,
view: ViewNumber,
sequence: SequenceNumber,
digest: MessageDigest,
replica_id: ReplicaId,
) -> Result<()> {
let should_commit = {
let mut state = self.state.lock().unwrap();
let entry = match state.log.get_mut(&sequence) {
Some(entry) if entry.view == view && entry.digest == digest => entry,
_ => return Ok(()), };
let msg = PbftMessage::Prepare {
view,
sequence,
digest: digest.clone(),
replica_id: replica_id.clone(),
};
entry.prepare_msgs.insert(replica_id, msg);
entry.prepare_msgs.len() + 1 >= self.config.quorum_size()
};
if should_commit {
let commit_msg = PbftMessage::Commit {
view,
sequence,
digest,
replica_id: self.replica_id.clone(),
};
timeout(
Time::from_millis(0),
self.config.commit_timeout,
self.transport.broadcast(commit_msg),
)
.await
.map_err(|_| Error::new(ErrorKind::DeadlineExceeded))??;
}
Ok(())
}
/// Records a `Commit` vote and then executes every batch that has become
/// executable, strictly in sequence order.
///
/// Votes for an unknown sequence, or whose view or digest differ from the
/// logged proposal, are silently ignored. Votes are stored per replica. A
/// batch is considered committed when `votes + 1 >= quorum_size()`.
///
/// Commit quorums may complete out of order. After recording the vote this
/// keeps executing the batch at `last_executed.next()` for as long as that
/// batch is committed and not yet executed, so a batch that reached its
/// quorum early is picked up as soon as its predecessor has run, without
/// waiting for another vote on it.
///
/// # Errors
///
/// Propagates any error from `execute_batch`.
///
/// # Panics
///
/// Panics if the state mutex is poisoned.
async fn handle_commit(
    &self,
    _cx: &Cx,
    view: ViewNumber,
    sequence: SequenceNumber,
    digest: MessageDigest,
    replica_id: ReplicaId,
) -> Result<()> {
    let quorum = self.config.quorum_size();

    {
        let mut state = self.state.lock().unwrap();
        let entry = match state.log.get_mut(&sequence) {
            Some(entry) if entry.view == view && entry.digest == digest => entry,
            _ => return Ok(()),
        };
        let msg = PbftMessage::Commit {
            view,
            sequence,
            digest,
            replica_id: replica_id.clone(),
        };
        entry.commit_msgs.insert(replica_id, msg);
    }

    loop {
        // Decide under the lock, execute after releasing it.
        let ready = {
            let state = self.state.lock().unwrap();
            let candidate = state.last_executed.next();
            match state.log.get(&candidate) {
                // `result.is_none()` guarantees progress: an executed batch
                // is never selected again.
                Some(entry)
                    if entry.result.is_none() && entry.commit_msgs.len() + 1 >= quorum =>
                {
                    Some(candidate)
                }
                _ => None,
            }
        };
        let Some(next) = ready else {
            break;
        };
        self.execute_batch(next).await?;
    }
    Ok(())
}
async fn execute_batch(&self, sequence: SequenceNumber) -> Result<()> {
let batch = {
let mut state = self.state.lock().unwrap();
state.last_executed = sequence;
let entry = state.log.get_mut(&sequence).unwrap();
let batch = entry.batch.clone();
let result = Outcome::Ok(b"executed".to_vec());
entry.result = Some(result);
batch
};
let batch_size = batch.len();
#[cfg(feature = "tracing-integration")]
tracing::info!(
replica_id = %self.replica_id,
sequence = %sequence,
batch_size,
"Executed consensus batch"
);
#[cfg(not(feature = "tracing-integration"))]
let _ = batch_size;
Ok(())
}
/// Handler for `ViewChange` messages.
///
/// Currently a no-op: all arguments are ignored and `Ok(())` is returned.
/// NOTE(review): the view-change protocol is not implemented here.
async fn handle_view_change(
&self,
_cx: &Cx,
_new_view: ViewNumber,
_replica_id: ReplicaId,
_certificates: Vec<MessageCertificate>,
) -> Result<()> {
Ok(())
}
/// Handler for `NewView` messages.
///
/// Currently a no-op: all arguments are ignored and `Ok(())` is returned.
/// NOTE(review): installing a new view is not implemented here.
async fn handle_new_view(
&self,
_cx: &Cx,
_view: ViewNumber,
_view_change_msgs: Vec<PbftMessage>,
_preprepare_msgs: Vec<PbftMessage>,
) -> Result<()> {
Ok(())
}
}
/// Facade over a single `PbftNode`: accepts requests and drives the node's
/// message loop.
pub struct PbftConsensus<T: PbftTransport> {
// The wrapped replica.
node: PbftNode<T>,
}
impl<T: PbftTransport> PbftConsensus<T> {
    /// Creates the facade together with its underlying replica.
    ///
    /// # Errors
    ///
    /// Returns whatever `PbftNode::new` returns, i.e.
    /// `ErrorKind::InvalidInput` for an invalid configuration.
    pub fn new(replica_id: ReplicaId, config: PbftConfig, transport: T) -> Result<Self> {
        let node = PbftNode::new(replica_id, config, transport)?;
        Ok(Self { node })
    }

    /// Hands `request` to the local replica and returns a response.
    ///
    /// NOTE(review): the response is a placeholder built from constants
    /// (view 0, sequence 0, payload `b"consensus result"`, timestamp 0) as
    /// soon as the request has been queued and, on the primary, proposed. It
    /// does not wait for, or reflect, the outcome of consensus.
    ///
    /// # Errors
    ///
    /// Propagates any error from `PbftNode::submit_request`.
    pub async fn submit(&self, cx: &Cx, request: ConsensusRequest) -> Result<ConsensusResponse> {
        // `request` is owned and not used afterwards, so it is moved rather
        // than cloned.
        self.node.submit_request(cx, request).await?;
        Ok(ConsensusResponse {
            view: ViewNumber::new(0),
            sequence: SequenceNumber::new(0),
            result: Outcome::Ok(b"consensus result".to_vec()),
            replica_id: self.node.replica_id.clone(),
            timestamp: Time::from_millis(0),
        })
    }

    /// Receives and processes messages until an error occurs.
    ///
    /// This never returns `Ok`: the loop ends with the first error from
    /// either the transport or message processing.
    ///
    /// NOTE(review): because processing errors are fatal, one rejected
    /// message (for example a `PrePrepare` for a stale view) stops the
    /// replica's message loop. Consider whether callers should restart `run`
    /// or whether per-message errors should be treated as non-fatal.
    pub async fn run(&self, cx: &Cx) -> Result<()> {
        loop {
            let message = self.node.transport.receive().await?;
            self.node.process_message(cx, message).await?;
        }
    }
}