use std::collections::BTreeMap;
use std::sync::Arc;
use crate::crypto::{self, Ciphertext, DecryptionShare};
use failure::Fail;
use rand::Rng;
use rand_derive::Rand;
use serde::{Deserialize, Serialize};
use crate::fault_log::{self, Fault};
use crate::{ConsensusProtocol, NetworkInfo, NodeIdT, Target};
/// A threshold decryption error.
#[derive(Clone, Eq, PartialEq, Debug, Fail)]
pub enum Error {
    /// `set_ciphertext` was called although a ciphertext had already been set. Carries the
    /// rejected ciphertext.
    #[fail(display = "Redundant input provided: {:?}", _0)]
    MultipleInputs(Box<Ciphertext>),
    /// The ciphertext passed to `set_ciphertext` did not pass `Ciphertext::verify`.
    #[fail(display = "Invalid ciphertext: {:?}", _0)]
    InvalidCiphertext(Box<Ciphertext>),
    /// A message arrived from a node for which the network info has no node index.
    #[fail(display = "Unknown sender")]
    UnknownSender,
    /// Combining the collected decryption shares into the plaintext failed.
    #[fail(display = "Decryption failed: {:?}", _0)]
    Decryption(crypto::error::Error),
    /// `start_decryption` was called before a ciphertext was set.
    #[fail(display = "Tried to decrypt before setting ciphertext")]
    CiphertextIsNone,
}
/// A threshold decryption result, using this module's `Error`.
pub type Result<T> = ::std::result::Result<T, Error>;
/// A fault committed by another node during threshold decryption.
#[derive(Clone, Debug, Fail, PartialEq)]
pub enum FaultKind {
    /// `ThresholdDecrypt` received more than one decryption share from the same sender.
    #[fail(display = "`ThresholdDecrypt` received multiple shares from the same sender.")]
    MultipleDecryptionShares,
    /// `ThresholdDecrypt` received a decryption share that failed verification against the
    /// sender's public key share, or a share from a sender without a public key share.
    // The display text previously named `HoneyBadger`, although this fault is only ever
    // reported by `ThresholdDecrypt`.
    #[fail(display = "`ThresholdDecrypt` received a decryption share that failed verification.")]
    UnverifiedDecryptionShareSender,
}
/// A fault log whose entries use this module's `FaultKind`.
pub type FaultLog<N> = fault_log::FaultLog<N, FaultKind>;
/// A threshold decryption message: a single node's decryption share.
///
/// `start_decryption` broadcasts our own share to all nodes in this form.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Rand)]
pub struct Message(pub DecryptionShare);
/// A threshold decryption protocol instance.
///
/// It collects decryption shares from the nodes. Once the ciphertext is known and it holds
/// more than `num_faulty()` shares, it combines them into the plaintext.
#[derive(Debug)]
pub struct ThresholdDecrypt<N> {
    /// Shared network information: our ID, node indices, and key shares.
    netinfo: Arc<NetworkInfo<N>>,
    /// The ciphertext to decrypt, once it has been set.
    ciphertext: Option<Ciphertext>,
    /// The decryption shares collected so far, keyed by sender, each paired with the sender's
    /// node index. Shares received before the ciphertext was set have not been verified yet.
    shares: BTreeMap<N, (usize, DecryptionShare)>,
    /// Whether `start_decryption` has started decryption (our own share, if we have a secret
    /// key share, has then been broadcast).
    had_input: bool,
    /// Whether this instance has terminated; set when the shares are combined.
    terminated: bool,
}
/// A step produced by a `ThresholdDecrypt` instance: outgoing messages, faults and output.
pub type Step<N> = crate::CpStep<ThresholdDecrypt<N>>;
/// Exposes `ThresholdDecrypt` through the generic `ConsensusProtocol` interface.
///
/// Every method is a thin adapter over the inherent API; the random number generators are
/// not needed by threshold decryption and are ignored.
impl<N: NodeIdT> ConsensusProtocol for ThresholdDecrypt<N> {
    type NodeId = N;
    type Input = ();
    type Output = Vec<u8>;
    type Message = Message;
    type Error = Error;
    type FaultKind = FaultKind;

    /// Returns this node's own ID, as recorded in the network info.
    fn our_id(&self) -> &Self::NodeId {
        self.netinfo.our_id()
    }

    /// Reports whether the plaintext has been output (or decryption was attempted).
    fn terminated(&self) -> bool {
        self.terminated
    }

    /// The unit input only signals "start decrypting"; see `start_decryption`.
    fn handle_input<R: Rng>(
        &mut self,
        _input: Self::Input,
        _rng: &mut R,
    ) -> Result<Step<Self::NodeId>> {
        self.start_decryption()
    }

    /// Forwards an incoming decryption share to the inherent `handle_message`.
    ///
    /// Method-call syntax resolves to the inherent method, which takes precedence over this
    /// trait method of the same name.
    fn handle_message<R: Rng>(
        &mut self,
        sender_id: &Self::NodeId,
        message: Self::Message,
        _rng: &mut R,
    ) -> Result<Step<Self::NodeId>> {
        self.handle_message(sender_id, message)
    }
}
impl<N: NodeIdT> ThresholdDecrypt<N> {
pub fn new(netinfo: Arc<NetworkInfo<N>>) -> Self {
ThresholdDecrypt {
netinfo,
ciphertext: None,
shares: BTreeMap::new(),
had_input: false,
terminated: false,
}
}
pub fn new_with_ciphertext(netinfo: Arc<NetworkInfo<N>>, ct: Ciphertext) -> Result<Self> {
let mut td = ThresholdDecrypt::new(netinfo);
td.set_ciphertext(ct)?;
Ok(td)
}
pub fn set_ciphertext(&mut self, ct: Ciphertext) -> Result<()> {
if self.ciphertext.is_some() {
return Err(Error::MultipleInputs(Box::new(ct)));
}
if !ct.verify() {
return Err(Error::InvalidCiphertext(Box::new(ct.clone())));
}
self.ciphertext = Some(ct);
Ok(())
}
pub fn start_decryption(&mut self) -> Result<Step<N>> {
if self.had_input {
return Ok(Step::default()); }
let ct = self.ciphertext.clone().ok_or(Error::CiphertextIsNone)?;
let mut step = Step::default();
step.fault_log.extend(self.remove_invalid_shares());
self.had_input = true;
let opt_idx = self.netinfo.node_index(self.our_id());
let (idx, share) = match (opt_idx, self.netinfo.secret_key_share()) {
(Some(idx), Some(sks)) => (idx, sks.decrypt_share_no_verify(&ct)),
(_, _) => return Ok(step.join(self.try_output()?)), };
let our_id = self.our_id().clone();
let msg = Target::All.message(Message(share.clone()));
step.messages.push(msg);
self.shares.insert(our_id, (idx, share));
step.extend(self.try_output()?);
Ok(step)
}
pub fn sender_ids(&self) -> impl Iterator<Item = &N> {
self.shares.keys()
}
pub fn handle_message(&mut self, sender_id: &N, message: Message) -> Result<Step<N>> {
if self.terminated {
return Ok(Step::default()); }
let idx = self
.netinfo
.node_index(sender_id)
.ok_or(Error::UnknownSender)?;
let Message(share) = message;
if !self.is_share_valid(sender_id, &share) {
let fault_kind = FaultKind::UnverifiedDecryptionShareSender;
return Ok(Fault::new(sender_id.clone(), fault_kind).into());
}
let entry = (idx, share);
if self.shares.insert(sender_id.clone(), entry).is_some() {
return Ok(Fault::new(sender_id.clone(), FaultKind::MultipleDecryptionShares).into());
}
self.try_output()
}
fn remove_invalid_shares(&mut self) -> FaultLog<N> {
let faulty_senders: Vec<N> = self
.shares
.iter()
.filter(|(id, (_, share))| !self.is_share_valid(id, share))
.map(|(id, _)| id.clone())
.collect();
let mut fault_log = FaultLog::default();
for id in faulty_senders {
self.shares.remove(&id);
fault_log.append(id, FaultKind::UnverifiedDecryptionShareSender);
}
fault_log
}
fn is_share_valid(&self, id: &N, share: &DecryptionShare) -> bool {
let ct = match self.ciphertext {
None => return true, Some(ref ct) => ct,
};
match self.netinfo.public_key_share(id) {
None => false, Some(pk) => pk.verify_decryption_share(share, ct),
}
}
fn try_output(&mut self) -> Result<Step<N>> {
if self.terminated || self.shares.len() <= self.netinfo.num_faulty() {
return Ok(Step::default()); }
let ct = match self.ciphertext {
None => return Ok(Step::default()), Some(ref ct) => ct.clone(),
};
self.terminated = true;
let step = self.start_decryption()?; let share_itr = self
.shares
.values()
.map(|&(ref idx, ref share)| (idx, share));
let plaintext = self
.netinfo
.public_key_set()
.decrypt(share_itr, &ct)
.map_err(Error::Decryption)?;
Ok(step.with_output(plaintext))
}
}