/// Reasons an incoming communication identity can be rejected by
/// [`CommunicationConsumption::consume_receive`].
///
/// The variant order and field layout are part of the serde representation;
/// do not reorder them without a format migration.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum CommunicationReplayError {
    /// Sequence mode: the identity's `sequence_no` was not the next
    /// sequence number expected for its edge.
    SequenceMismatch {
        /// Sequence number the receiver expected for this edge.
        expected: u64,
        /// Sequence number carried by the rejected identity.
        actual: u64,
    },
    /// Nullifier mode: the nullifier derived from the identity has already
    /// been consumed.
    DuplicateIdentity {
        /// Nullifier derived from the rejected identity.
        nullifier: Nullifier,
    },
}
impl CommunicationReplayError {
#[must_use]
pub fn tag(&self) -> &'static str {
match self {
Self::SequenceMismatch { .. } => COMM_REPLAY_SEQUENCE_MISMATCH_TAG,
Self::DuplicateIdentity { .. } => COMM_REPLAY_DUPLICATE_TAG,
}
}
}
/// Outcome of a successful [`CommunicationConsumption::consume_receive`] call.
///
/// Field order is part of the serde representation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CommunicationConsumeResult {
    /// Replay mode that was in effect when the identity was consumed.
    pub mode: CommunicationReplayMode,
    /// Replay-state root observed before the identity was consumed.
    pub pre_root: Hash,
    /// Replay-state root observed after the identity was consumed
    /// (equal to `pre_root` when nothing was recorded, e.g. mode `Off`).
    pub post_root: Hash,
    /// Nullifier recorded by this consumption. In
    /// `DefaultCommunicationConsumption` this is `Some` only in
    /// `Nullifier` mode and `None` in `Off` and `Sequence` modes.
    pub consumed_nullifier: Option<Nullifier>,
}
/// Serializable record pairing a consumed communication identity with the
/// replay mode and the roots before/after its consumption.
///
/// NOTE(review): this type is not constructed in this part of the file; the
/// meaning of `tick` (e.g. which clock it is drawn from) should be confirmed
/// against the code that builds it. Field order is part of the serde
/// representation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CommunicationConsumptionArtifact {
    /// Tick at which the consumption happened (presumably a logical step
    /// counter — confirm with the producer of this artifact).
    pub tick: u64,
    /// The identity that was consumed.
    pub identity: CommunicationIdentity,
    /// Replay mode in effect for the consumption.
    pub mode: CommunicationReplayMode,
    /// Replay-state root before the consumption.
    pub pre_root: Hash,
    /// Replay-state root after the consumption.
    pub post_root: Hash,
}
/// Replay-protection model for communication identities.
///
/// Implementors track per-edge send/receive sequence counters and consumed
/// nullifiers, and expose a root hash committing to that state.
pub trait CommunicationConsumption {
    /// Current replay mode.
    fn mode(&self) -> CommunicationReplayMode;
    /// Replaces the current replay mode.
    fn set_mode(&mut self, mode: CommunicationReplayMode);
    /// Borrow of the underlying replay state.
    fn state(&self) -> &CommunicationReplayState;
    /// Root hash committing to the current replay state.
    fn root(&self) -> Hash;
    /// Returns the next send sequence number for `edge` and advances the
    /// per-edge counter.
    fn allocate_send_sequence(&mut self, edge: &Edge) -> u64;
    /// Checks `identity` against the replay state for the current mode and,
    /// if accepted, records it.
    ///
    /// # Errors
    ///
    /// Returns a [`CommunicationReplayError`] when the identity is rejected
    /// (out-of-order sequence number or already-consumed nullifier,
    /// depending on the mode).
    fn consume_receive(
        &mut self,
        identity: &CommunicationIdentity,
    ) -> Result<CommunicationConsumeResult, CommunicationReplayError>;
    /// Drops the sequence counters belonging to session `sid`.
    fn prune_session(&mut self, sid: SessionId);
}
/// Default [`CommunicationConsumption`] implementation backed by a
/// [`CommunicationReplayState`] plus an incrementally maintained root cache.
///
/// Only `Serialize` is derived. Deserialization is implemented by hand so the
/// skipped `root_cache` can be rebuilt from the deserialized `state`. Field
/// order is part of the serde representation.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct DefaultCommunicationConsumption {
    /// Replay-protection mode used by `consume_receive`.
    #[serde(default)]
    pub mode: CommunicationReplayMode,
    /// Sequence counters and consumed nullifiers.
    #[serde(default)]
    pub state: CommunicationReplayState,
    /// Derived cache of the state root. It is not serialized and is rebuilt
    /// from `state` on construction and deserialization.
    ///
    /// NOTE(review): `state` is `pub`, so code that mutates it directly
    /// bypasses the cache updates made by the trait methods — confirm no
    /// caller does this.
    #[serde(skip)]
    root_cache: ReplayRootCache,
}
impl DefaultCommunicationConsumption {
    /// Builds a model in `mode` with an empty replay state and a root cache
    /// computed from that empty state.
    #[must_use]
    pub fn new(mode: CommunicationReplayMode) -> Self {
        let mut fresh = Self {
            mode,
            state: Default::default(),
            root_cache: Default::default(),
        };
        fresh.rebuild_root_cache();
        fresh
    }

    /// Root hash of the current replay state, read from the cache.
    #[must_use]
    pub fn root(&self) -> Hash {
        let cache = &self.root_cache;
        cache.root()
    }

    /// Recomputes the root cache from scratch from `self.state`.
    fn rebuild_root_cache(&mut self) {
        let rebuilt = ReplayRootCache::from_state(&self.state);
        self.root_cache = rebuilt;
    }
}
/// Derives the nullifier for `identity`.
///
/// The identity is encoded with `replay_binary_encode`, and that encoding is
/// hashed under `HashTag::Nullifier`.
fn identity_nullifier(identity: &CommunicationIdentity) -> Nullifier {
    let encoded = replay_binary_encode(identity);
    let digest = DefaultVerificationModel::hash(HashTag::Nullifier, &encoded);
    Nullifier(digest)
}
/// Deserialization shape for [`DefaultCommunicationConsumption`].
///
/// It mirrors the serialized fields (`mode`, `state`) in the same order; each
/// falls back to its `Default` when absent. The root cache is intentionally
/// absent and is rebuilt after deserialization.
#[derive(Debug, Deserialize)]
struct DefaultCommunicationConsumptionSerde {
    /// Replay mode; defaults when missing from the input.
    #[serde(default)]
    mode: CommunicationReplayMode,
    /// Replay state; defaults when missing from the input.
    #[serde(default)]
    state: CommunicationReplayState,
}
impl<'de> Deserialize<'de> for DefaultCommunicationConsumption {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let raw = DefaultCommunicationConsumptionSerde::deserialize(deserializer)?;
let mut model = Self {
mode: raw.mode,
state: raw.state,
root_cache: ReplayRootCache::default(),
};
model.rebuild_root_cache();
Ok(model)
}
}
impl CommunicationConsumption for DefaultCommunicationConsumption {
    fn mode(&self) -> CommunicationReplayMode {
        self.mode
    }

    /// Replaces the replay mode. Neither `state` nor the root cache is touched.
    fn set_mode(&mut self, mode: CommunicationReplayMode) {
        self.mode = mode;
    }

    fn state(&self) -> &CommunicationReplayState {
        &self.state
    }

    /// Delegates to the inherent `DefaultCommunicationConsumption::root`.
    fn root(&self) -> Hash {
        DefaultCommunicationConsumption::root(self)
    }

    /// Returns the next send sequence number for `edge` (0 for an edge never
    /// seen before) and advances the stored counter, keeping the root cache
    /// in sync.
    ///
    /// NOTE(review): the counter advances with `saturating_add`, so once it
    /// reaches `u64::MAX` the same sequence number is returned on every
    /// subsequent call.
    fn allocate_send_sequence(&mut self, edge: &Edge) -> u64 {
        let previous = self.state.next_send_sequence.get(edge).copied();
        let entry = self
            .state
            .next_send_sequence
            .entry(edge.clone())
            .or_insert(0);
        let sequence_no = *entry;
        *entry = entry.saturating_add(1);
        self.root_cache.update_send_sequence(edge, previous, *entry);
        sequence_no
    }

    /// Validates `identity` against the replay state for the current mode
    /// and records it on success:
    ///
    /// * `Off`: nothing is checked or recorded.
    /// * `Sequence`: `identity.sequence_no` must equal the next expected
    ///   receive sequence number for its edge (0 if unseen). On success the
    ///   expected number advances by one.
    /// * `Nullifier`: the nullifier derived from the identity must not have
    ///   been consumed yet. On success it is recorded and returned.
    ///
    /// The result carries the root before and after the call.
    ///
    /// # Errors
    ///
    /// * [`CommunicationReplayError::SequenceMismatch`] in `Sequence` mode
    ///   when the sequence number is not the expected one.
    /// * [`CommunicationReplayError::DuplicateIdentity`] in `Nullifier` mode
    ///   when the nullifier was already consumed.
    ///
    /// On error, neither `state` nor the root cache has been modified.
    fn consume_receive(
        &mut self,
        identity: &CommunicationIdentity,
    ) -> Result<CommunicationConsumeResult, CommunicationReplayError> {
        let pre_root = self.root();
        let consumed_nullifier = match self.mode {
            CommunicationReplayMode::Off => None,
            CommunicationReplayMode::Sequence => {
                let edge = identity.edge();
                // One lookup provides both the cache's "previous" value and
                // the expected sequence number.
                let previous = self.state.next_recv_sequence.get(&edge).copied();
                let expected = previous.unwrap_or(0);
                if identity.sequence_no != expected {
                    return Err(CommunicationReplayError::SequenceMismatch {
                        expected,
                        actual: identity.sequence_no,
                    });
                }
                // NOTE(review): `saturating_add` keeps the counter at
                // `u64::MAX` once reached, after which an identity with that
                // sequence number would be accepted repeatedly.
                let next = expected.saturating_add(1);
                // Update the cache before moving `edge` into the map, so the
                // edge does not have to be recomputed from `identity`.
                self.root_cache.update_recv_sequence(&edge, previous, next);
                self.state.next_recv_sequence.insert(edge, next);
                None
            }
            CommunicationReplayMode::Nullifier => {
                let nullifier = identity_nullifier(identity);
                if self.state.consumed_nullifiers.contains(&nullifier) {
                    return Err(CommunicationReplayError::DuplicateIdentity { nullifier });
                }
                self.state.consumed_nullifiers.insert(nullifier);
                self.root_cache.insert_nullifier(nullifier);
                Some(nullifier)
            }
        };
        let post_root = self.root();
        Ok(CommunicationConsumeResult {
            mode: self.mode,
            pre_root,
            post_root,
            consumed_nullifier,
        })
    }

    /// Removes every send and receive sequence counter whose edge belongs to
    /// session `sid`, and removes the corresponding entries from the root
    /// cache. Consumed nullifiers are left untouched.
    fn prune_session(&mut self, sid: SessionId) {
        // Snapshot the (edge, counter) pairs being removed so the root cache
        // can be told exactly which values disappeared.
        let send_pruned: Vec<_> = self
            .state
            .next_send_sequence
            .iter()
            .filter_map(|(edge, sequence_no)| {
                (edge.sid == sid).then_some((edge.clone(), *sequence_no))
            })
            .collect();
        let recv_pruned: Vec<_> = self
            .state
            .next_recv_sequence
            .iter()
            .filter_map(|(edge, sequence_no)| {
                (edge.sid == sid).then_some((edge.clone(), *sequence_no))
            })
            .collect();
        self.state
            .next_send_sequence
            .retain(|edge, _| edge.sid != sid);
        self.state
            .next_recv_sequence
            .retain(|edge, _| edge.sid != sid);
        for (edge, sequence_no) in send_pruned {
            self.root_cache.remove_send_sequence(&edge, sequence_no);
        }
        for (edge, sequence_no) in recv_pruned {
            self.root_cache.remove_recv_sequence(&edge, sequence_no);
        }
    }
}