use crate::{
codec::{decode_u32_items, encode_u32_items, CodecError, Decode, Encode, ParameterizedDecode},
vdaf::{Aggregator, PrepareTransition, VdafError},
};
use std::fmt::Debug;
/// Errors that can occur while driving ping-pong preparation.
///
/// The `Vdaf*` variants wrap a [`VdafError`] returned by the corresponding VDAF method. The
/// `Codec*` variants wrap a [`CodecError`] from encoding or decoding a prepare share or prepare
/// message. The `*Mismatch` variants report a protocol-level inconsistency between what the host
/// expected and what it was given.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum PingPongError {
    /// `Aggregator::prepare_init` failed while initializing the leader or helper.
    #[error("vdaf.prepare_init: {0}")]
    VdafPrepareInit(VdafError),

    /// `Aggregator::prepare_shares_to_prepare_message` failed to combine the two prepare shares.
    #[error("vdaf.prepare_shares_to_prepare_message {0}")]
    VdafPrepareSharesToPrepareMessage(VdafError),

    /// `Aggregator::prepare_next` failed to apply a prepare message to a prepare state.
    #[error("vdaf.prepare_next {0}")]
    VdafPrepareNext(VdafError),

    /// A prepare share could not be encoded (outbound) or decoded (inbound).
    #[error("encode/decode prep share {0}")]
    CodecPrepShare(CodecError),

    /// A prepare message could not be encoded (outbound) or decoded (inbound).
    #[error("encode/decode prep message {0}")]
    CodecPrepMessage(CodecError),

    /// The host's own state was not the one required by the requested operation. For example,
    /// the host was asked to continue while its state is already finished.
    #[error("host state mismatch: in {found} expected {expected}")]
    HostStateMismatch {
        /// Short name of the state the host is actually in.
        found: &'static str,
        /// Short name of the state the operation required.
        expected: &'static str,
    },

    /// The message received from the peer is not of the kind the host expected at this step.
    #[error("peer message mismatch: message is {found} expected {expected}")]
    PeerMessageMismatch {
        /// Variant name of the received message (see `PingPongMessage`).
        found: &'static str,
        /// Short name of the message kind the host expected.
        expected: &'static str,
    },

    /// Internal error with a static description.
    ///
    /// NOTE(review): not constructed anywhere in this section; presumably reserved for
    /// invariant violations elsewhere in the crate. Confirm before relying on it.
    #[error("internal error: {0}")]
    InternalError(&'static str),
}
/// A message exchanged between the leader and the helper during ping-pong preparation.
///
/// Payloads are opaque, already-encoded byte strings:
/// - `Initialize` carries only the sender's encoded prepare share. The leader sends it first.
/// - `Continue` carries the encoded prepare message for the current round and the sender's
///   encoded prepare share for the next round.
/// - `Finish` carries the encoded prepare message for the final round.
#[derive(Clone, PartialEq, Eq)]
pub enum PingPongMessage {
    /// First message of the exchange: the sender's encoded prepare share.
    Initialize {
        /// Encoded prepare share.
        prep_share: Vec<u8>,
    },
    /// Intermediate message: a prepare message plus the sender's next prepare share.
    Continue {
        /// Encoded prepare message.
        prep_msg: Vec<u8>,
        /// Encoded prepare share.
        prep_share: Vec<u8>,
    },
    /// Final message: the last prepare message, with no further share.
    Finish {
        /// Encoded prepare message.
        prep_msg: Vec<u8>,
    },
}

impl PingPongMessage {
    /// Returns the variant's name ("Initialize", "Continue" or "Finish") for use in error
    /// reports and `Debug` output.
    fn variant(&self) -> &'static str {
        match *self {
            PingPongMessage::Initialize { .. } => "Initialize",
            PingPongMessage::Continue { .. } => "Continue",
            PingPongMessage::Finish { .. } => "Finish",
        }
    }
}

/// Formats only the variant name; the payload bytes are deliberately left out of the output
/// (presumably to keep potentially sensitive share data out of logs).
impl Debug for PingPongMessage {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A field-less tuple struct formats as just its name, so writing the name directly is
        // equivalent, including under `{:#?}`.
        f.write_str(self.variant())
    }
}
/// Wire encoding of a [`PingPongMessage`]: a one-byte tag (`0` = `Initialize`, `1` = `Continue`,
/// `2` = `Finish`) followed by each payload as a byte string with a `u32` length prefix.
/// `Continue` writes `prep_msg` before `prep_share`.
impl Encode for PingPongMessage {
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        let tag: u8 = match self {
            Self::Initialize { .. } => 0,
            Self::Continue { .. } => 1,
            Self::Finish { .. } => 2,
        };
        tag.encode(bytes)?;

        match self {
            Self::Initialize { prep_share } => encode_u32_items(bytes, &(), prep_share),
            Self::Continue {
                prep_msg,
                prep_share,
            } => {
                encode_u32_items(bytes, &(), prep_msg)?;
                encode_u32_items(bytes, &(), prep_share)
            }
            Self::Finish { prep_msg } => encode_u32_items(bytes, &(), prep_msg),
        }
    }

    fn encoded_len(&self) -> Option<usize> {
        // One tag byte, then for every payload a 4-byte length prefix plus the payload itself.
        const TAG_LEN: usize = 1;
        const LENGTH_PREFIX_LEN: usize = 4;

        let payload_len = match self {
            Self::Initialize { prep_share } => LENGTH_PREFIX_LEN + prep_share.len(),
            Self::Continue {
                prep_msg,
                prep_share,
            } => LENGTH_PREFIX_LEN + prep_msg.len() + LENGTH_PREFIX_LEN + prep_share.len(),
            Self::Finish { prep_msg } => LENGTH_PREFIX_LEN + prep_msg.len(),
        };
        Some(TAG_LEN + payload_len)
    }
}
/// Inverse of the `Encode` impl: reads the one-byte tag, then the `u32`-length-prefixed
/// payload(s) for that variant. Any tag other than 0, 1 or 2 is rejected with
/// `CodecError::UnexpectedValue`.
impl Decode for PingPongMessage {
    fn decode(bytes: &mut std::io::Cursor<&[u8]>) -> Result<Self, CodecError> {
        match u8::decode(bytes)? {
            0 => Ok(Self::Initialize {
                prep_share: decode_u32_items(&(), bytes)?,
            }),
            1 => {
                // On the wire the prepare message precedes the prepare share.
                let prep_msg = decode_u32_items(&(), bytes)?;
                let prep_share = decode_u32_items(&(), bytes)?;
                Ok(Self::Continue {
                    prep_msg,
                    prep_share,
                })
            }
            2 => Ok(Self::Finish {
                prep_msg: decode_u32_items(&(), bytes)?,
            }),
            _ => Err(CodecError::UnexpectedValue),
        }
    }
}
/// A pending preparation step: a prepare state together with the prepare message to apply to it.
///
/// Transitions are produced by [`PingPongTopology::helper_initialized`] and by the
/// [`PingPongContinuedValue::WithMessage`] outcome of `leader_continued`/`helper_continued`.
/// Calling [`PingPongTransition::evaluate`] runs the VDAF's `prepare_next` and yields both the
/// host's next state and the message to send to the peer. Evaluation is separate from
/// construction, so a transition can be stored (it implements `Encode`/`ParameterizedDecode`
/// when the prepare state is encodable) and evaluated later.
///
/// `PartialEq` and `Eq` are implemented by hand rather than derived. A derive would add
/// `A: PartialEq`/`A: Eq` bounds on the VDAF type itself, which plays no part in comparing two
/// transitions.
#[derive(Clone, Debug)]
pub struct PingPongTransition<
    const VERIFY_KEY_SIZE: usize,
    const NONCE_SIZE: usize,
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
> {
    /// Prepare state that `current_prepare_message` will be applied to.
    previous_prepare_state: A::PrepareState,
    /// Prepare message computed from both aggregators' prepare shares.
    current_prepare_message: A::PrepareMessage,
}

impl<
        const VERIFY_KEY_SIZE: usize,
        const NONCE_SIZE: usize,
        A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
    > PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
{
    /// Applies the stored prepare message to the stored prepare state.
    ///
    /// Returns the host's new state and the message to send to the peer:
    /// - If the VDAF continues, the state is `Continued` and the message is
    ///   `PingPongMessage::Continue`, carrying the encoded prepare message and the host's encoded
    ///   next prepare share.
    /// - If the VDAF finishes, the state is `Finished` (holding the output share) and the message
    ///   is `PingPongMessage::Finish`, carrying the encoded prepare message.
    ///
    /// `self` is not consumed; the state and message are cloned into `prepare_next`.
    ///
    /// # Errors
    ///
    /// - [`PingPongError::CodecPrepMessage`] if the prepare message cannot be encoded.
    /// - [`PingPongError::VdafPrepareNext`] if `prepare_next` fails.
    /// - [`PingPongError::CodecPrepShare`] if the next prepare share cannot be encoded.
    // The return type is a nested generic tuple; naming it would not make it clearer.
    #[allow(clippy::type_complexity)]
    pub fn evaluate(
        &self,
        ctx: &[u8],
        vdaf: &A,
    ) -> Result<
        (
            PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, A>,
            PingPongMessage,
        ),
        PingPongError,
    > {
        // Encode the prepare message first: it is forwarded to the peer whichever way the
        // VDAF transition goes.
        let prep_msg = self
            .current_prepare_message
            .get_encoded()
            .map_err(PingPongError::CodecPrepMessage)?;

        let transition = vdaf
            .prepare_next(
                ctx,
                self.previous_prepare_state.clone(),
                self.current_prepare_message.clone(),
            )
            .map_err(PingPongError::VdafPrepareNext)?;

        match transition {
            PrepareTransition::Continue(prep_state, prep_share) => {
                let prep_share = prep_share
                    .get_encoded()
                    .map_err(PingPongError::CodecPrepShare)?;
                Ok((
                    PingPongState::Continued(prep_state),
                    PingPongMessage::Continue {
                        prep_msg,
                        prep_share,
                    },
                ))
            }
            PrepareTransition::Finish(output_share) => Ok((
                PingPongState::Finished(output_share),
                PingPongMessage::Finish { prep_msg },
            )),
        }
    }
}

/// Two transitions are equal when both their prepare state and their prepare message are
/// equal. No bound is placed on the VDAF type `A` itself.
impl<
        const VERIFY_KEY_SIZE: usize,
        const NONCE_SIZE: usize,
        A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
    > PartialEq for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
{
    fn eq(&self, other: &Self) -> bool {
        self.previous_prepare_state == other.previous_prepare_state
            && self.current_prepare_message == other.current_prepare_message
    }
}

/// Implemented by hand for the same reason as `PartialEq`: `#[derive(Eq)]` would demand
/// `A: Eq`, which most VDAF types do not implement and which equality here never uses.
impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A> Eq
    for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
where
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
    A::PrepareState: Eq,
    A::PrepareMessage: Eq,
{
}
/// Serializes a transition as the encoded prepare state immediately followed by the encoded
/// prepare message, with no tag or length prefix between them. It is only available when the
/// VDAF's prepare state is itself `Encode`.
impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A> Encode
    for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
where
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
    A::PrepareState: Encode,
{
    fn encode(&self, bytes: &mut Vec<u8>) -> Result<(), CodecError> {
        let Self {
            previous_prepare_state,
            current_prepare_message,
        } = self;
        previous_prepare_state.encode(bytes)?;
        current_prepare_message.encode(bytes)
    }

    fn encoded_len(&self) -> Option<usize> {
        // Unknown if either component cannot report its length.
        let state_len = self.previous_prepare_state.encoded_len()?;
        let message_len = self.current_prepare_message.encoded_len()?;
        Some(state_len + message_len)
    }
}
/// Decodes the layout written by the `Encode` impl.
///
/// The prepare state is decoded first, using `decoding_param`. The just-decoded state is then
/// the decoding parameter for the prepare message that follows it.
impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A, PrepareStateDecode>
    ParameterizedDecode<PrepareStateDecode> for PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>
where
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
    A::PrepareState: ParameterizedDecode<PrepareStateDecode> + PartialEq,
    A::PrepareMessage: PartialEq,
{
    fn decode_with_param(
        decoding_param: &PrepareStateDecode,
        bytes: &mut std::io::Cursor<&[u8]>,
    ) -> Result<Self, CodecError> {
        let state = A::PrepareState::decode_with_param(decoding_param, bytes)?;
        // The prepare message's encoding depends on the state it will be applied to.
        let message = A::PrepareMessage::decode_with_param(&state, bytes)?;
        Ok(Self {
            previous_prepare_state: state,
            current_prepare_message: message,
        })
    }
}
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PingPongState<
const VERIFY_KEY_SIZE: usize,
const NONCE_SIZE: usize,
A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
> {
Continued(A::PrepareState),
Finished(A::OutputShare),
}
/// Outcome of `leader_continued`/`helper_continued` after the host processes a peer message.
#[derive(Clone, Debug)]
pub enum PingPongContinuedValue<
    const VERIFY_KEY_SIZE: usize,
    const NONCE_SIZE: usize,
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
> {
    /// The peer sent `Continue` and the host's VDAF also continued. Evaluate `transition` to
    /// obtain the host's next state and the message to send back to the peer.
    WithMessage {
        /// Pending step built from the host's next prepare state and the combined prepare
        /// message.
        transition: PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, A>,
    },
    /// The peer sent `Finish` and the host's VDAF also finished. There is nothing more to send.
    FinishedNoMessage {
        /// The host's output share.
        output_share: A::OutputShare,
    },
}
/// Drives VDAF preparation between two aggregators, a leader and a helper, that take turns
/// sending [`PingPongMessage`]s.
///
/// A blanket implementation later in this module provides this trait for every
/// `Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>`. Its associated types are:
/// - `State`: [`PingPongState`]
/// - `ContinuedValue`: [`PingPongContinuedValue`]
/// - `Transition`: [`PingPongTransition`]
pub trait PingPongTopology<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>:
    Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>
{
    /// The host's state between steps.
    type State;
    /// Result of processing a peer message in `leader_continued`/`helper_continued`.
    type ContinuedValue;
    /// A pending step that, when evaluated, yields the next state and outbound message.
    type Transition;

    /// Starts preparation on the leader.
    ///
    /// In the blanket impl this calls `prepare_init` with aggregator ID 0. It returns a
    /// `Continued` state and a `PingPongMessage::Initialize` carrying the leader's encoded
    /// prepare share, which is the first message sent to the helper.
    fn leader_initialized(
        &self,
        verify_key: &[u8; VERIFY_KEY_SIZE],
        ctx: &[u8],
        agg_param: &Self::AggregationParam,
        nonce: &[u8; NONCE_SIZE],
        public_share: &Self::PublicShare,
        input_share: &Self::InputShare,
    ) -> Result<(Self::State, PingPongMessage), PingPongError>;

    /// Starts preparation on the helper, given the leader's `Initialize` message.
    ///
    /// In the blanket impl this calls `prepare_init` with aggregator ID 1 and decodes the
    /// leader's prepare share from `inbound`. It then combines the shares (leader first) into a
    /// prepare message. The returned transition must be evaluated to get the helper's state and
    /// reply.
    ///
    /// Returns `PeerMessageMismatch` if `inbound` is not `Initialize`.
    #[allow(clippy::too_many_arguments)]
    fn helper_initialized(
        &self,
        verify_key: &[u8; VERIFY_KEY_SIZE],
        ctx: &[u8],
        agg_param: &Self::AggregationParam,
        nonce: &[u8; NONCE_SIZE],
        public_share: &Self::PublicShare,
        input_share: &Self::InputShare,
        inbound: &PingPongMessage,
    ) -> Result<PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>, PingPongError>;

    /// Advances the leader with the helper's latest message (`Continue` or `Finish`).
    ///
    /// `leader_state` must still be `Continued`; otherwise `HostStateMismatch` is returned.
    fn leader_continued(
        &self,
        ctx: &[u8],
        leader_state: Self::State,
        agg_param: &Self::AggregationParam,
        inbound: &PingPongMessage,
    ) -> Result<Self::ContinuedValue, PingPongError>;

    /// Advances the helper with the leader's latest message (`Continue` or `Finish`).
    ///
    /// `helper_state` must still be `Continued`; otherwise `HostStateMismatch` is returned.
    fn helper_continued(
        &self,
        ctx: &[u8],
        helper_state: Self::State,
        agg_param: &Self::AggregationParam,
        inbound: &PingPongMessage,
    ) -> Result<Self::ContinuedValue, PingPongError>;
}
/// Module-private helper shared by `leader_continued` and `helper_continued`.
trait PingPongTopologyPrivate<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize>:
    PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE>
{
    /// Processes a `Continue` or `Finish` message from the peer.
    ///
    /// `is_leader` tells the implementation which side the host is. This determines the order
    /// in which the host's and the peer's prepare shares are combined.
    fn continued(
        &self,
        ctx: &[u8],
        is_leader: bool,
        host_state: Self::State,
        agg_param: &Self::AggregationParam,
        inbound: &PingPongMessage,
    ) -> Result<Self::ContinuedValue, PingPongError>;
}
impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A>
PingPongTopology<VERIFY_KEY_SIZE, NONCE_SIZE> for A
where
A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
type State = PingPongState<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
type ContinuedValue = PingPongContinuedValue<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
type Transition = PingPongTransition<VERIFY_KEY_SIZE, NONCE_SIZE, Self>;
fn leader_initialized(
&self,
verify_key: &[u8; VERIFY_KEY_SIZE],
ctx: &[u8],
agg_param: &Self::AggregationParam,
nonce: &[u8; NONCE_SIZE],
public_share: &Self::PublicShare,
input_share: &Self::InputShare,
) -> Result<(Self::State, PingPongMessage), PingPongError> {
self.prepare_init(
verify_key,
ctx,
0,
agg_param,
nonce,
public_share,
input_share,
)
.map_err(PingPongError::VdafPrepareInit)
.and_then(|(prep_state, prep_share)| {
Ok((
PingPongState::Continued(prep_state),
PingPongMessage::Initialize {
prep_share: prep_share
.get_encoded()
.map_err(PingPongError::CodecPrepShare)?,
},
))
})
}
fn helper_initialized(
&self,
verify_key: &[u8; VERIFY_KEY_SIZE],
ctx: &[u8],
agg_param: &Self::AggregationParam,
nonce: &[u8; NONCE_SIZE],
public_share: &Self::PublicShare,
input_share: &Self::InputShare,
inbound: &PingPongMessage,
) -> Result<Self::Transition, PingPongError> {
let (prep_state, prep_share) = self
.prepare_init(
verify_key,
ctx,
1,
agg_param,
nonce,
public_share,
input_share,
)
.map_err(PingPongError::VdafPrepareInit)?;
let inbound_prep_share = if let PingPongMessage::Initialize { prep_share } = inbound {
Self::PrepareShare::get_decoded_with_param(&prep_state, prep_share)
.map_err(PingPongError::CodecPrepShare)?
} else {
return Err(PingPongError::PeerMessageMismatch {
found: inbound.variant(),
expected: "initialize",
});
};
let current_prepare_message = self
.prepare_shares_to_prepare_message(ctx, agg_param, [inbound_prep_share, prep_share])
.map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?;
Ok(PingPongTransition {
previous_prepare_state: prep_state,
current_prepare_message,
})
}
fn leader_continued(
&self,
ctx: &[u8],
leader_state: Self::State,
agg_param: &Self::AggregationParam,
inbound: &PingPongMessage,
) -> Result<Self::ContinuedValue, PingPongError> {
self.continued(ctx, true, leader_state, agg_param, inbound)
}
fn helper_continued(
&self,
ctx: &[u8],
helper_state: Self::State,
agg_param: &Self::AggregationParam,
inbound: &PingPongMessage,
) -> Result<Self::ContinuedValue, PingPongError> {
self.continued(ctx, false, helper_state, agg_param, inbound)
}
}
/// Blanket implementation of the shared "process a peer message" step.
impl<const VERIFY_KEY_SIZE: usize, const NONCE_SIZE: usize, A>
    PingPongTopologyPrivate<VERIFY_KEY_SIZE, NONCE_SIZE> for A
where
    A: Aggregator<VERIFY_KEY_SIZE, NONCE_SIZE>,
{
    /// The steps run in this order, and the first failure is returned:
    /// 1. The host must still be `Continued`.
    /// 2. `inbound` must be `Continue` or `Finish`.
    /// 3. The inbound prepare message is decoded and applied via `prepare_next`.
    /// 4. The VDAF outcome must agree with the peer's message kind. If both continue, the next
    ///    shares are combined into a new transition. If both finish, the output share is
    ///    returned.
    fn continued(
        &self,
        ctx: &[u8],
        is_leader: bool,
        host_state: Self::State,
        agg_param: &Self::AggregationParam,
        inbound: &PingPongMessage,
    ) -> Result<Self::ContinuedValue, PingPongError> {
        let host_prep_state = match host_state {
            PingPongState::Continued(prep_state) => prep_state,
            PingPongState::Finished(_) => {
                return Err(PingPongError::HostStateMismatch {
                    found: "finished",
                    expected: "continue",
                });
            }
        };

        // Split the peer message into its prepare message and, for `Continue` only, the
        // peer's next prepare share.
        let (encoded_prep_msg, encoded_peer_share) = match inbound {
            PingPongMessage::Continue {
                prep_msg,
                prep_share,
            } => (prep_msg, Some(prep_share)),
            PingPongMessage::Finish { prep_msg } => (prep_msg, None),
            PingPongMessage::Initialize { .. } => {
                return Err(PingPongError::PeerMessageMismatch {
                    found: inbound.variant(),
                    expected: "continue",
                });
            }
        };

        let prep_msg =
            Self::PrepareMessage::get_decoded_with_param(&host_prep_state, encoded_prep_msg)
                .map_err(PingPongError::CodecPrepMessage)?;
        let transition = self
            .prepare_next(ctx, host_prep_state, prep_msg)
            .map_err(PingPongError::VdafPrepareNext)?;

        match transition {
            PrepareTransition::Continue(next_prep_state, host_share) => {
                // The host wants another round, so the peer must have sent its next share.
                let encoded_peer_share = match encoded_peer_share {
                    Some(share) => share,
                    None => {
                        return Err(PingPongError::PeerMessageMismatch {
                            found: inbound.variant(),
                            expected: "continue",
                        });
                    }
                };
                let peer_share = Self::PrepareShare::get_decoded_with_param(
                    &next_prep_state,
                    encoded_peer_share,
                )
                .map_err(PingPongError::CodecPrepShare)?;

                // The VDAF expects shares in aggregator order: leader first, then helper.
                let prep_shares = if is_leader {
                    [host_share, peer_share]
                } else {
                    [peer_share, host_share]
                };
                let current_prepare_message = self
                    .prepare_shares_to_prepare_message(ctx, agg_param, prep_shares)
                    .map_err(PingPongError::VdafPrepareSharesToPrepareMessage)?;

                Ok(PingPongContinuedValue::WithMessage {
                    transition: PingPongTransition {
                        previous_prepare_state: next_prep_state,
                        current_prepare_message,
                    },
                })
            }
            PrepareTransition::Finish(output_share) => match encoded_peer_share {
                None => Ok(PingPongContinuedValue::FinishedNoMessage { output_share }),
                // The host finished but the peer still expects another round.
                Some(_) => Err(PingPongError::PeerMessageMismatch {
                    found: inbound.variant(),
                    expected: "finish",
                }),
            },
        }
    }
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use super::*;
use crate::vdaf::dummy;
use assert_matches::assert_matches;
const CTX_STR: &[u8] = b"pingpong ctx";
#[test]
fn ping_pong_one_round() {
let verify_key = [];
let aggregation_param = dummy::AggregationParam(0);
let nonce = [0; 16];
#[allow(clippy::let_unit_value)]
let public_share = ();
let input_share = dummy::InputShare(0);
let leader = dummy::Vdaf::new(1);
let helper = dummy::Vdaf::new(1);
let (leader_state, leader_message) = leader
.leader_initialized(
&verify_key,
CTX_STR,
&aggregation_param,
&nonce,
&public_share,
&input_share,
)
.unwrap();
let (helper_state, helper_message) = helper
.helper_initialized(
&verify_key,
CTX_STR,
&aggregation_param,
&nonce,
&public_share,
&input_share,
&leader_message,
)
.unwrap()
.evaluate(CTX_STR, &helper)
.unwrap();
assert_matches!(helper_state, PingPongState::Finished(_));
let leader_state = leader
.leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message)
.unwrap();
assert_matches!(
leader_state,
PingPongContinuedValue::FinishedNoMessage { .. }
);
}
#[test]
fn ping_pong_two_rounds() {
let verify_key = [];
let aggregation_param = dummy::AggregationParam(0);
let nonce = [0; 16];
#[allow(clippy::let_unit_value)]
let public_share = ();
let input_share = dummy::InputShare(0);
let leader = dummy::Vdaf::new(2);
let helper = dummy::Vdaf::new(2);
let (leader_state, leader_message) = leader
.leader_initialized(
&verify_key,
CTX_STR,
&aggregation_param,
&nonce,
&public_share,
&input_share,
)
.unwrap();
let (helper_state, helper_message) = helper
.helper_initialized(
&verify_key,
CTX_STR,
&aggregation_param,
&nonce,
&public_share,
&input_share,
&leader_message,
)
.unwrap()
.evaluate(CTX_STR, &helper)
.unwrap();
assert_matches!(helper_state, PingPongState::Continued(_));
let leader_state = leader
.leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message)
.unwrap();
let leader_message = assert_matches!(
leader_state, PingPongContinuedValue::WithMessage { transition } => {
let (state, message) = transition.evaluate(CTX_STR,&leader).unwrap();
assert_matches!(state, PingPongState::Finished(_));
message
}
);
let helper_state = helper
.helper_continued(CTX_STR, helper_state, &aggregation_param, &leader_message)
.unwrap();
assert_matches!(
helper_state,
PingPongContinuedValue::FinishedNoMessage { .. }
);
}
#[test]
fn ping_pong_three_rounds() {
let verify_key = [];
let aggregation_param = dummy::AggregationParam(0);
let nonce = [0; 16];
#[allow(clippy::let_unit_value)]
let public_share = ();
let input_share = dummy::InputShare(0);
let leader = dummy::Vdaf::new(3);
let helper = dummy::Vdaf::new(3);
let (leader_state, leader_message) = leader
.leader_initialized(
&verify_key,
CTX_STR,
&aggregation_param,
&nonce,
&public_share,
&input_share,
)
.unwrap();
let (helper_state, helper_message) = helper
.helper_initialized(
&verify_key,
CTX_STR,
&aggregation_param,
&nonce,
&public_share,
&input_share,
&leader_message,
)
.unwrap()
.evaluate(CTX_STR, &helper)
.unwrap();
assert_matches!(helper_state, PingPongState::Continued(_));
let leader_state = leader
.leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message)
.unwrap();
let (leader_state, leader_message) = assert_matches!(
leader_state, PingPongContinuedValue::WithMessage { transition } => {
let (state, message) = transition.evaluate(CTX_STR,&leader).unwrap();
assert_matches!(state, PingPongState::Continued(_));
(state, message)
}
);
let helper_state = helper
.helper_continued(CTX_STR, helper_state, &aggregation_param, &leader_message)
.unwrap();
let helper_message = assert_matches!(
helper_state, PingPongContinuedValue::WithMessage { transition } => {
let (state, message) = transition.evaluate(CTX_STR,&helper).unwrap();
assert_matches!(state, PingPongState::Finished(_));
message
}
);
let leader_state = leader
.leader_continued(CTX_STR, leader_state, &aggregation_param, &helper_message)
.unwrap();
assert_matches!(
leader_state,
PingPongContinuedValue::FinishedNoMessage { .. }
);
}
#[test]
fn roundtrip_message() {
let messages = [
(
PingPongMessage::Initialize {
prep_share: Vec::from("prepare share"),
},
concat!(
"00", concat!(
"0000000d", "70726570617265207368617265", ),
),
),
(
PingPongMessage::Continue {
prep_msg: Vec::from("prepare message"),
prep_share: Vec::from("prepare share"),
},
concat!(
"01", concat!(
"0000000f", "70726570617265206d657373616765", ),
concat!(
"0000000d", "70726570617265207368617265", ),
),
),
(
PingPongMessage::Finish {
prep_msg: Vec::from("prepare message"),
},
concat!(
"02", concat!(
"0000000f", "70726570617265206d657373616765", ),
),
),
];
for (message, expected_hex) in messages {
let mut encoded_val = Vec::new();
message.encode(&mut encoded_val).unwrap();
let got_hex = hex::encode(&encoded_val);
assert_eq!(
&got_hex, expected_hex,
"Couldn't roundtrip (encoded value differs): {message:?}",
);
let decoded_val = PingPongMessage::decode(&mut Cursor::new(&encoded_val)).unwrap();
assert_eq!(
decoded_val, message,
"Couldn't roundtrip (decoded value differs): {message:?}"
);
assert_eq!(
encoded_val.len(),
message.encoded_len().expect("No encoded length hint"),
"Encoded length hint is incorrect: {message:?}"
)
}
}
#[test]
fn roundtrip_transition() {
let transition = PingPongTransition::<0, 16, dummy::Vdaf> {
previous_prepare_state: dummy::PrepareState::default(),
current_prepare_message: (),
};
let encoded = transition.get_encoded().unwrap();
let hex_encoded = hex::encode(&encoded);
assert_eq!(
hex_encoded,
concat!(
concat!(
"00", "00000000", ),
)
);
let decoded = PingPongTransition::get_decoded_with_param(&(), &encoded).unwrap();
assert_eq!(transition, decoded);
assert_eq!(
encoded.len(),
transition.encoded_len().expect("No encoded length hint"),
);
}
}