use alloc::{boxed::Box, collections::BTreeMap};
use core::fmt::{self, Debug};
use rand_core::CryptoRngCore;
use serde::{Deserialize, Serialize};
use crate::protocol::{
Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome,
LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage,
ProtocolValidationError, ReceiveError, RequiredMessages, Round, RoundId, TransitionInfo,
};
/// A marker trait with no items.
///
/// In this file it is required, in addition to [`ChainedProtocol`] or [`ChainedSplit`],
/// by the blanket `Protocol` and `EntryPoint` implementations, so only types that
/// explicitly opt in by implementing this marker receive those implementations.
pub trait ChainedMarker {}
/// Describes a pair of protocols that are executed one after another.
///
/// The result type of the combined protocol is the result type of `Protocol2`
/// (see the blanket `Protocol` implementation for types that also implement [`ChainedMarker`]).
pub trait ChainedProtocol<Id>: 'static + Debug {
/// The protocol that is executed first.
type Protocol1: Protocol<Id>;
/// The protocol that is executed second.
type Protocol2: Protocol<Id>;
}
/// The protocol error type of a chained protocol.
///
/// Wraps the protocol error of whichever of the two chained protocols produced it.
/// The explicit serde bounds require only the wrapped error types to be
/// (de)serializable.
#[derive_where::derive_where(Debug, Clone)]
#[derive(Serialize, Deserialize)]
#[serde(bound(serialize = "
<C::Protocol1 as Protocol<Id>>::ProtocolError: Serialize,
<C::Protocol2 as Protocol<Id>>::ProtocolError: Serialize,
"))]
#[serde(bound(deserialize = "
<C::Protocol1 as Protocol<Id>>::ProtocolError: for<'x> Deserialize<'x>,
<C::Protocol2 as Protocol<Id>>::ProtocolError: for<'x> Deserialize<'x>,
"))]
pub enum ChainedProtocolError<Id, C>
where
C: ChainedProtocol<Id>,
{
/// A protocol error of the first protocol in the chain.
Protocol1(<C::Protocol1 as Protocol<Id>>::ProtocolError),
/// A protocol error of the second protocol in the chain.
Protocol2(<C::Protocol2 as Protocol<Id>>::ProtocolError),
}
// Human-readable rendering: the wrapped error's own `Display` output,
// prefixed with the number of the protocol it originated from.
impl<Id, C> fmt::Display for ChainedProtocolError<Id, C>
where
    C: ChainedProtocol<Id>,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            Self::Protocol1(inner) => f.write_fmt(format_args!("Protocol 1: {inner}")),
            Self::Protocol2(inner) => f.write_fmt(format_args!("Protocol 2: {inner}")),
        }
    }
}
impl<Id, C> ChainedProtocolError<Id, C>
where
C: ChainedProtocol<Id>,
{
fn from_protocol1(err: <C::Protocol1 as Protocol<Id>>::ProtocolError) -> Self {
Self::Protocol1(err)
}
fn from_protocol2(err: <C::Protocol2 as Protocol<Id>>::ProtocolError) -> Self {
Self::Protocol2(err)
}
}
/// The associated data needed to verify a [`ChainedProtocolError`].
///
/// Holds the associated data for the error types of both chained protocols;
/// during verification only the field matching the error's variant is used.
#[derive_where::derive_where(Debug)]
pub struct ChainedAssociatedData<Id, C>
where
C: ChainedProtocol<Id>,
{
/// Associated data for errors of the first protocol.
pub protocol1: <<C::Protocol1 as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
/// Associated data for errors of the second protocol.
pub protocol2: <<C::Protocol2 as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
}
impl<Id, C> ProtocolError<Id> for ChainedProtocolError<Id, C>
where
    C: ChainedProtocol<Id>,
{
    type AssociatedData = ChainedAssociatedData<Id, C>;

    /// Returns the messages required by the wrapped error, with every round ID
    /// it refers to re-grouped under the number of the protocol the error
    /// belongs to (1 or 2). The `this_round` requirement is passed through as is.
    fn required_messages(&self) -> RequiredMessages {
        let (group, inner) = match self {
            Self::Protocol1(err) => (1, err.required_messages()),
            Self::Protocol2(err) => (2, err.required_messages()),
        };
        let RequiredMessages {
            this_round,
            previous_rounds,
            combined_echos,
        } = inner;

        RequiredMessages {
            this_round,
            previous_rounds: previous_rounds.map(|rounds| {
                rounds
                    .into_iter()
                    .map(|(id, parts)| (id.group_under(group), parts))
                    .collect()
            }),
            combined_echos: combined_echos
                .map(|rounds| rounds.into_iter().map(|id| id.group_under(group)).collect()),
        }
    }

    /// Delegates verification to the wrapped error.
    ///
    /// The round IDs keying `previous_messages` and `combined_echos` are first
    /// stripped of their group prefix (the group number itself is discarded),
    /// so that the inner protocol sees its own round IDs. A failure to split
    /// any round ID is propagated as the returned error. The inner error is
    /// given the part of `associated_data` matching its protocol.
    #[allow(clippy::too_many_arguments)]
    fn verify_messages_constitute_error(
        &self,
        format: &BoxedFormat,
        guilty_party: &Id,
        shared_randomness: &[u8],
        associated_data: &Self::AssociatedData,
        message: ProtocolMessage,
        previous_messages: BTreeMap<RoundId, ProtocolMessage>,
        combined_echos: BTreeMap<RoundId, BTreeMap<Id, EchoBroadcast>>,
    ) -> Result<(), ProtocolValidationError> {
        let mut ungrouped_messages = BTreeMap::new();
        for (grouped_id, round_message) in previous_messages {
            let (_group_num, inner_id) = grouped_id.split_group()?;
            ungrouped_messages.insert(inner_id, round_message);
        }

        let mut ungrouped_echos = BTreeMap::new();
        for (grouped_id, echos) in combined_echos {
            let (_group_num, inner_id) = grouped_id.split_group()?;
            ungrouped_echos.insert(inner_id, echos);
        }

        match self {
            Self::Protocol1(err) => err.verify_messages_constitute_error(
                format,
                guilty_party,
                shared_randomness,
                &associated_data.protocol1,
                message,
                ungrouped_messages,
                ungrouped_echos,
            ),
            Self::Protocol2(err) => err.verify_messages_constitute_error(
                format,
                guilty_party,
                shared_randomness,
                &associated_data.protocol2,
                message,
                ungrouped_messages,
                ungrouped_echos,
            ),
        }
    }
}
// Blanket `Protocol` implementation for every chained protocol that opted in
// via `ChainedMarker`. The result is the result of the second protocol.
//
// Each verification function splits the group number off the round ID and
// forwards to the first protocol when the group is 1, and to the second
// protocol for any other group number.
impl<Id, C> Protocol<Id> for C
where
    Id: 'static,
    C: ChainedProtocol<Id> + ChainedMarker,
{
    type Result = <C::Protocol2 as Protocol<Id>>::Result;
    type ProtocolError = ChainedProtocolError<Id, C>;

    /// Forwards the direct message check to the protocol owning `round_id`.
    fn verify_direct_message_is_invalid(
        format: &BoxedFormat,
        round_id: &RoundId,
        message: &DirectMessage,
    ) -> Result<(), MessageValidationError> {
        match round_id.split_group()? {
            (1, inner_id) => C::Protocol1::verify_direct_message_is_invalid(format, &inner_id, message),
            (_, inner_id) => C::Protocol2::verify_direct_message_is_invalid(format, &inner_id, message),
        }
    }

    /// Forwards the echo broadcast check to the protocol owning `round_id`.
    fn verify_echo_broadcast_is_invalid(
        format: &BoxedFormat,
        round_id: &RoundId,
        message: &EchoBroadcast,
    ) -> Result<(), MessageValidationError> {
        match round_id.split_group()? {
            (1, inner_id) => C::Protocol1::verify_echo_broadcast_is_invalid(format, &inner_id, message),
            (_, inner_id) => C::Protocol2::verify_echo_broadcast_is_invalid(format, &inner_id, message),
        }
    }

    /// Forwards the normal broadcast check to the protocol owning `round_id`.
    fn verify_normal_broadcast_is_invalid(
        format: &BoxedFormat,
        round_id: &RoundId,
        message: &NormalBroadcast,
    ) -> Result<(), MessageValidationError> {
        match round_id.split_group()? {
            (1, inner_id) => C::Protocol1::verify_normal_broadcast_is_invalid(format, &inner_id, message),
            (_, inner_id) => C::Protocol2::verify_normal_broadcast_is_invalid(format, &inner_id, message),
        }
    }
}
/// The starting half of a chained protocol's entry point.
///
/// Types implementing this trait together with [`ChainedMarker`] receive a blanket
/// `EntryPoint` implementation for the chained protocol in this file.
pub trait ChainedSplit<Id: PartyId> {
/// The chained protocol this entry point starts.
type Protocol: ChainedProtocol<Id> + ChainedMarker;
/// The entry point of the first protocol in the chain.
type EntryPoint: EntryPoint<Id, Protocol = <Self::Protocol as ChainedProtocol<Id>>::Protocol1>;
/// Consumes `self` and returns the entry point of the first protocol, along with
/// a [`ChainedJoin`] value that is later used to create the entry point
/// of the second protocol.
fn make_entry_point1(self) -> (Self::EntryPoint, impl ChainedJoin<Id, Protocol = Self::Protocol>);
}
/// The joining half of a chained protocol: creates the entry point of the second
/// protocol once the result of the first protocol is available.
pub trait ChainedJoin<Id: PartyId>: 'static + Debug + Send + Sync {
/// The chained protocol this value belongs to.
type Protocol: ChainedProtocol<Id> + ChainedMarker;
/// The entry point of the second protocol in the chain.
type EntryPoint: EntryPoint<Id, Protocol = <Self::Protocol as ChainedProtocol<Id>>::Protocol2>;
/// Consumes `self` and the result of the first protocol,
/// and returns the entry point of the second protocol.
fn make_entry_point2(
self,
result: <<Self::Protocol as ChainedProtocol<Id>>::Protocol1 as Protocol<Id>>::Result,
) -> Self::EntryPoint;
}
// Blanket `EntryPoint` implementation for every `ChainedSplit` type that opted
// in via `ChainedMarker`: the chained protocol starts with the first protocol.
impl<Id, T> EntryPoint<Id> for T
where
    Id: PartyId,
    T: ChainedSplit<Id> + ChainedMarker,
{
    type Protocol = T::Protocol;

    /// The entry round of the first protocol, grouped under group 1.
    fn entry_round_id() -> RoundId {
        let inner_id = <T as ChainedSplit<Id>>::EntryPoint::entry_round_id();
        inner_id.group_under(1)
    }

    /// Creates the first round of the first protocol and wraps it in a chained round.
    ///
    /// The party ID and the shared randomness are copied into the chained round's
    /// state so that the second protocol can later be started with the same values.
    fn make_round(
        self,
        rng: &mut dyn CryptoRngCore,
        shared_randomness: &[u8],
        id: &Id,
    ) -> Result<BoxedRound<Id, Self::Protocol>, LocalError> {
        let (first_entry_point, transition) = self.make_entry_point1();
        let first_round = first_entry_point.make_round(rng, shared_randomness, id)?;

        let state = ChainState::Protocol1 {
            id: id.clone(),
            shared_randomness: shared_randomness.into(),
            transition,
            round: first_round,
        };
        Ok(BoxedRound::new_dynamic(ChainedRound { state }))
    }
}
/// A round of a chained protocol.
///
/// Wraps a round of either the first or the second protocol (see [`ChainState`])
/// and implements `Round` for the chained protocol by delegating to it.
#[derive(Debug)]
struct ChainedRound<Id, T>
where
Id: PartyId,
T: ChainedJoin<Id>,
{
/// Which protocol is currently running, along with its current round.
state: ChainState<Id, T>,
}
/// The state of a chained round: which of the two protocols is being executed.
#[derive_where::derive_where(Debug)]
enum ChainState<Id, T>
where
Id: PartyId,
T: ChainedJoin<Id>,
{
/// The first protocol is running. Besides the current round, this keeps
/// everything needed to start the second protocol when the first one finishes.
Protocol1 {
/// The party ID passed to `make_round`; reused to create the first round of the second protocol.
id: Id,
/// The current round of the first protocol.
round: BoxedRound<Id, <T::Protocol as ChainedProtocol<Id>>::Protocol1>,
/// The shared randomness passed to `make_round`; reused for the second protocol.
shared_randomness: Box<[u8]>,
/// Creates the entry point of the second protocol from the result of the first.
transition: T,
},
/// The second protocol is running; holds its current round.
Protocol2(BoxedRound<Id, <T::Protocol as ChainedProtocol<Id>>::Protocol2>),
}
impl<Id, T> Round<Id> for ChainedRound<Id, T>
where
Id: PartyId,
T: ChainedJoin<Id>,
{
type Protocol = T::Protocol;
fn transition_info(&self) -> TransitionInfo {
match &self.state {
ChainState::Protocol1 { round, .. } => {
let mut tinfo = round.as_ref().transition_info().group_under(1);
if tinfo.may_produce_result {
tinfo.may_produce_result = false;
tinfo.children.insert(T::EntryPoint::entry_round_id().group_under(2));
}
tinfo
}
ChainState::Protocol2(round) => round.as_ref().transition_info().group_under(2),
}
}
fn communication_info(&self) -> CommunicationInfo<Id> {
match &self.state {
ChainState::Protocol1 { round, .. } => round.as_ref().communication_info(),
ChainState::Protocol2(round) => round.as_ref().communication_info(),
}
}
fn make_direct_message(
&self,
rng: &mut dyn CryptoRngCore,
format: &BoxedFormat,
destination: &Id,
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
match &self.state {
ChainState::Protocol1 { round, .. } => round.as_ref().make_direct_message(rng, format, destination),
ChainState::Protocol2(round) => round.as_ref().make_direct_message(rng, format, destination),
}
}
fn make_echo_broadcast(
&self,
rng: &mut dyn CryptoRngCore,
format: &BoxedFormat,
) -> Result<EchoBroadcast, LocalError> {
match &self.state {
ChainState::Protocol1 { round, .. } => round.as_ref().make_echo_broadcast(rng, format),
ChainState::Protocol2(round) => round.as_ref().make_echo_broadcast(rng, format),
}
}
fn make_normal_broadcast(
&self,
rng: &mut dyn CryptoRngCore,
format: &BoxedFormat,
) -> Result<NormalBroadcast, LocalError> {
match &self.state {
ChainState::Protocol1 { round, .. } => round.as_ref().make_normal_broadcast(rng, format),
ChainState::Protocol2(round) => round.as_ref().make_normal_broadcast(rng, format),
}
}
fn receive_message(
&self,
format: &BoxedFormat,
from: &Id,
message: ProtocolMessage,
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
match &self.state {
ChainState::Protocol1 { round, .. } => match round.as_ref().receive_message(format, from, message) {
Ok(payload) => Ok(payload),
Err(err) => Err(err.map(ChainedProtocolError::from_protocol1)),
},
ChainState::Protocol2(round) => match round.as_ref().receive_message(format, from, message) {
Ok(payload) => Ok(payload),
Err(err) => Err(err.map(ChainedProtocolError::from_protocol2)),
},
}
}
fn finalize(
self: Box<Self>,
rng: &mut dyn CryptoRngCore,
payloads: BTreeMap<Id, Payload>,
artifacts: BTreeMap<Id, Artifact>,
) -> Result<FinalizeOutcome<Id, Self::Protocol>, LocalError> {
match self.state {
ChainState::Protocol1 {
id,
round,
transition,
shared_randomness,
} => match round.into_boxed().finalize(rng, payloads, artifacts)? {
FinalizeOutcome::Result(result) => {
let entry_point2 = transition.make_entry_point2(result);
let round = entry_point2.make_round(rng, &shared_randomness, &id)?;
let chained_round = ChainedRound::<Id, T> {
state: ChainState::Protocol2(round),
};
Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(chained_round)))
}
FinalizeOutcome::AnotherRound(round) => {
let chained_round = ChainedRound::<Id, T> {
state: ChainState::Protocol1 {
id,
shared_randomness,
round,
transition,
},
};
Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(chained_round)))
}
},
ChainState::Protocol2(round) => match round.into_boxed().finalize(rng, payloads, artifacts)? {
FinalizeOutcome::Result(result) => Ok(FinalizeOutcome::Result(result)),
FinalizeOutcome::AnotherRound(round) => {
let chained_round = ChainedRound::<Id, T> {
state: ChainState::Protocol2(round),
};
Ok(FinalizeOutcome::AnotherRound(BoxedRound::new_dynamic(chained_round)))
}
},
}
}
}