use digest::Digest;
use futures::SinkExt;
use paillier_zk::{
no_small_factor::non_interactive as π_fac,
paillier_blum_modulus as π_mod,
rug::{Complete, Integer},
IntegerExt,
};
use rand_core::{CryptoRng, RngCore};
use round_based::{
rounds_router::{simple_store::RoundInput, RoundsRouter},
Delivery, Mpc, MpcParty, Outgoing, ProtocolMessage,
};
use serde::{Deserialize, Serialize};
use crate::{
errors::IoError,
key_share::{AuxInfo, DirtyAuxInfo, PartyAux, Validate},
progress::Tracer,
security_level::SecurityLevel,
utils,
utils::{collect_blame, AbortBlame},
zk::ring_pedersen_parameters as π_prm,
ExecutionId,
};
use super::{Bug, KeyRefreshError, PregeneratedPrimes, ProtocolAborted};
macro_rules! prefixed {
($name:tt) => {
concat!("dfns.cggmp21.aux_gen.", $name)
};
}
#[derive(ProtocolMessage, Clone, Serialize, Deserialize)]
#[serde(bound = "")]
#[allow(clippy::large_enum_variant)]
pub enum Msg<D: Digest, L: SecurityLevel> {
Round1(MsgRound1<D>),
Round2(MsgRound2<L>),
Round3(MsgRound3),
ReliabilityCheck(MsgReliabilityCheck<D>),
}
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[udigest(tag = prefixed!("round1"))]
#[udigest(bound = "")]
#[serde(bound = "")]
pub struct MsgRound1<D: Digest> {
#[udigest(as_bytes)]
pub commitment: digest::Output<D>,
}
#[derive(Clone, Serialize, Deserialize, udigest::Digestable)]
#[udigest(tag = prefixed!("round2"))]
#[udigest(bound = "")]
#[serde(bound = "")]
pub struct MsgRound2<L: SecurityLevel> {
#[udigest(as = utils::encoding::Integer)]
pub N: Integer,
#[udigest(as = utils::encoding::Integer)]
pub s: Integer,
#[udigest(as = utils::encoding::Integer)]
pub t: Integer,
pub params_proof: π_prm::Proof<{ crate::security_level::M }>,
#[serde(with = "hex")]
#[udigest(as_bytes)]
pub rho_bytes: L::Rid,
#[serde(with = "hex")]
#[udigest(as_bytes)]
pub decommit: L::Rid,
}
#[derive(Clone, Serialize, Deserialize)]
pub struct MsgRound3 {
pub mod_proof: (
π_mod::Commitment,
π_mod::Proof<{ crate::security_level::M }>,
),
pub fac_proof: π_fac::Proof,
}
#[derive(Clone, Serialize, Deserialize)]
#[serde(bound = "")]
pub struct MsgReliabilityCheck<D: Digest>(pub digest::Output<D>);
mod unambiguous {
use digest::Digest;
use crate::{ExecutionId, SecurityLevel};
#[derive(udigest::Digestable)]
#[udigest(tag = prefixed!("proof_prm"))]
pub struct ProofPrm<'a> {
pub sid: ExecutionId<'a>,
pub prover: u16,
}
#[derive(udigest::Digestable)]
#[udigest(tag = prefixed!("proof_mod"))]
pub struct ProofMod<'a> {
pub sid: ExecutionId<'a>,
#[udigest(as_bytes)]
pub rho: &'a [u8],
pub prover: u16,
}
#[derive(udigest::Digestable)]
#[udigest(tag = prefixed!("proof_fac"))]
#[udigest(bound = "")]
pub struct ProofFac<'a> {
pub sid: ExecutionId<'a>,
#[udigest(as_bytes)]
pub rho: &'a [u8],
pub prover: u16,
}
#[derive(udigest::Digestable)]
#[udigest(tag = prefixed!("hash_commitment"))]
#[udigest(bound = "")]
pub struct HashCom<'a, L: SecurityLevel> {
pub sid: ExecutionId<'a>,
pub prover: u16,
pub decommitment: &'a super::MsgRound2<L>,
}
#[derive(udigest::Digestable)]
#[udigest(tag = prefixed!("echo_round"))]
#[udigest(bound = "")]
pub struct Echo<'a, D: Digest> {
pub sid: ExecutionId<'a>,
pub commitment: &'a super::MsgRound1<D>,
}
}
pub async fn run_aux_gen<R, M, L, D>(
i: u16,
n: u16,
mut rng: &mut R,
party: M,
sid: ExecutionId<'_>,
pregenerated: PregeneratedPrimes<L>,
mut tracer: Option<&mut dyn Tracer>,
reliable_broadcast_enforced: bool,
compute_multiexp_table: bool,
compute_crt: bool,
) -> Result<AuxInfo<L>, KeyRefreshError>
where
R: RngCore + CryptoRng,
M: Mpc<ProtocolMessage = Msg<D, L>>,
L: SecurityLevel,
D: Digest<OutputSize = digest::typenum::U32> + Clone + 'static,
{
tracer.protocol_begins();
tracer.stage("Retrieve auxiliary data");
tracer.stage("Setup networking");
let MpcParty { delivery, .. } = party.into_party();
let (incomings, mut outgoings) = delivery.split();
let mut rounds = RoundsRouter::<Msg<D, L>>::builder();
let round1 = rounds.add_round(RoundInput::<MsgRound1<D>>::broadcast(i, n));
let round1_sync = rounds.add_round(RoundInput::<MsgReliabilityCheck<D>>::broadcast(i, n));
let round2 = rounds.add_round(RoundInput::<MsgRound2<L>>::broadcast(i, n));
let round3 = rounds.add_round(RoundInput::<MsgRound3>::p2p(i, n));
let mut rounds = rounds.listen(incomings);
tracer.round_begins();
tracer.stage("Retrieve primes (p and q)");
let PregeneratedPrimes { p, q, .. } = pregenerated;
tracer.stage("Compute paillier decryption key (N)");
let N = (&p * &q).complete();
let phi_N = (&p - 1u8).complete() * (&q - 1u8).complete();
tracer.stage("Generate auxiliary params r, λ, t, s");
let r = Integer::gen_invertible(&N, rng);
let lambda = phi_N
.random_below_ref(&mut utils::external_rand(rng))
.into();
let t = r.square().modulo(&N);
let s = t.pow_mod_ref(&lambda, &N).ok_or(Bug::PowMod)?.into();
tracer.stage("Prove Πprm (ψˆ_i)");
let hat_psi = π_prm::prove::<{ crate::security_level::M }, D>(
&unambiguous::ProofPrm { sid, prover: i },
&mut rng,
π_prm::Data {
N: &N,
s: &s,
t: &t,
},
&phi_N,
&lambda,
)
.map_err(Bug::PiPrm)?;
tracer.stage("Sample random bytes");
let mut rho_bytes = L::Rid::default();
rng.fill_bytes(rho_bytes.as_mut());
tracer.stage("Compute hash commitment and sample decommitment");
let decommitment = MsgRound2 {
N: N.clone(),
s: s.clone(),
t: t.clone(),
params_proof: hat_psi,
rho_bytes: rho_bytes.clone(),
decommit: {
let mut nonce = L::Rid::default();
rng.fill_bytes(nonce.as_mut());
nonce
},
};
let hash_commit = udigest::hash::<D>(&unambiguous::HashCom {
sid,
prover: i,
decommitment: &decommitment,
});
tracer.send_msg();
let commitment = MsgRound1 {
commitment: hash_commit,
};
outgoings
.send(Outgoing::broadcast(Msg::Round1(commitment.clone())))
.await
.map_err(IoError::send_message)?;
tracer.msg_sent();
tracer.round_begins();
tracer.receive_msgs();
let commitments = rounds
.complete(round1)
.await
.map_err(IoError::receive_message)?;
tracer.msgs_received();
if reliable_broadcast_enforced {
tracer.stage("Hash received msgs (reliability check)");
let h_i = udigest::hash_iter::<D>(
commitments
.iter_including_me(&commitment)
.map(|commitment| unambiguous::Echo { sid, commitment }),
);
tracer.send_msg();
outgoings
.send(Outgoing::broadcast(Msg::ReliabilityCheck(
MsgReliabilityCheck(h_i),
)))
.await
.map_err(IoError::send_message)?;
tracer.msg_sent();
tracer.round_begins();
tracer.receive_msgs();
let hashes = rounds
.complete(round1_sync)
.await
.map_err(IoError::receive_message)?;
tracer.msgs_received();
tracer.stage("Assert other parties hashed messages (reliability check)");
let parties_have_different_hashes = hashes
.into_iter_indexed()
.filter(|(_j, _msg_id, h_j)| h_i != h_j.0)
.map(|(j, msg_id, _)| AbortBlame::new(j, msg_id, msg_id))
.collect::<Vec<_>>();
if !parties_have_different_hashes.is_empty() {
return Err(ProtocolAborted::round1_not_reliable(parties_have_different_hashes).into());
}
}
tracer.send_msg();
outgoings
.send(Outgoing::broadcast(Msg::Round2(decommitment.clone())))
.await
.map_err(IoError::send_message)?;
tracer.msg_sent();
tracer.round_begins();
tracer.receive_msgs();
let decommitments = rounds
.complete(round2)
.await
.map_err(IoError::receive_message)?;
tracer.msgs_received();
tracer.stage("Validate round 1 decommitments");
let blame = collect_blame(&decommitments, &commitments, |j, decomm, comm| {
let com_expected = udigest::hash::<D>(&unambiguous::HashCom {
sid,
prover: j,
decommitment: decomm,
});
com_expected != comm.commitment
});
if !blame.is_empty() {
return Err(ProtocolAborted::invalid_decommitment(blame).into());
}
tracer.stage("Validate П_prm (ψ_i)");
let blame = collect_blame(&decommitments, &decommitments, |j, d, _| {
if !crate::security_level::validate_public_paillier_key_size::<L>(&d.N) {
true
} else {
let data = π_prm::Data {
N: &d.N,
s: &d.s,
t: &d.t,
};
π_prm::verify::<{ crate::security_level::M }, D>(
&unambiguous::ProofPrm { sid, prover: j },
data,
&d.params_proof,
)
.is_err()
}
});
if !blame.is_empty() {
return Err(ProtocolAborted::invalid_ring_pedersen_parameters(blame).into());
}
tracer.stage("Add together shared random bytes");
let rho_bytes = decommitments
.iter()
.map(|d| &d.rho_bytes)
.fold(rho_bytes, utils::xor_array);
tracer.stage("Compute П_mod (ψ_i)");
let psi = π_mod::non_interactive::prove::<{ crate::security_level::M }, D>(
&unambiguous::ProofMod {
sid,
rho: rho_bytes.as_ref(),
prover: i,
},
&π_mod::Data { n: N.clone() },
&π_mod::PrivateData {
p: p.clone(),
q: q.clone(),
},
&mut rng,
)
.map_err(Bug::PiMod)?;
tracer.stage("Assemble security params for П_fac (ф_i)");
let π_fac_security = π_fac::SecurityParams {
l: L::ELL,
epsilon: L::EPSILON,
q: L::q(),
};
let n_sqrt = utils::sqrt(&N);
for (j, _, d) in decommitments.iter_indexed() {
tracer.send_msg();
tracer.stage("Compute П_fac (ф_i^j)");
let phi = π_fac::prove::<D>(
&unambiguous::ProofFac {
sid,
rho: rho_bytes.as_ref(),
prover: i,
},
&π_fac::Aux {
s: d.s.clone(),
t: d.t.clone(),
rsa_modulo: d.N.clone(),
multiexp: None,
crt: None,
},
π_fac::Data {
n: &N,
n_root: &n_sqrt,
},
π_fac::PrivateData { p: &p, q: &q },
&π_fac_security,
&mut rng,
)
.map_err(Bug::PiFac)?;
tracer.send_msg();
let msg = MsgRound3 {
mod_proof: psi.clone(),
fac_proof: phi.clone(),
};
outgoings
.feed(Outgoing::p2p(j, Msg::Round3(msg)))
.await
.map_err(IoError::send_message)?;
tracer.msg_sent();
}
tracer.send_msg();
outgoings.flush().await.map_err(IoError::send_message)?;
tracer.msg_sent();
tracer.round_begins();
tracer.receive_msgs();
let shares_msg_b = rounds
.complete(round3)
.await
.map_err(IoError::receive_message)?;
tracer.msgs_received();
tracer.stage("Validate ψ_j (П_mod)");
let blame = collect_blame(
&decommitments,
&shares_msg_b,
|j, decommitment, proof_msg| {
let data = π_mod::Data {
n: decommitment.N.clone(),
};
let (comm, proof) = &proof_msg.mod_proof;
π_mod::non_interactive::verify::<{ crate::security_level::M }, D>(
&unambiguous::ProofMod {
sid,
rho: rho_bytes.as_ref(),
prover: j,
},
&data,
comm,
proof,
)
.is_err()
},
);
if !blame.is_empty() {
return Err(ProtocolAborted::invalid_mod_proof(blame).into());
}
tracer.stage("Validate ф_j (П_fac)");
let crt = if compute_crt {
Some(paillier_zk::fast_paillier::utils::CrtExp::build_n(&p, &q).ok_or(Bug::BuildCrt)?)
} else {
None
};
let phi_common_aux = π_fac::Aux {
s: s.clone(),
t: t.clone(),
rsa_modulo: N.clone(),
multiexp: None,
crt: crt.clone(),
};
let blame = collect_blame(
&decommitments,
&shares_msg_b,
|j, decommitment, proof_msg| {
π_fac::verify::<D>(
&unambiguous::ProofFac {
sid,
rho: rho_bytes.as_ref(),
prover: j,
},
&phi_common_aux,
π_fac::Data {
n: &decommitment.N,
n_root: &utils::sqrt(&decommitment.N),
},
&π_fac_security,
&proof_msg.fac_proof,
)
.is_err()
},
);
if !blame.is_empty() {
return Err(ProtocolAborted::invalid_fac_proof(blame).into());
}
tracer.stage("Assemble auxiliary info");
let mut party_auxes = decommitments
.iter_including_me(&decommitment)
.map(|d| PartyAux {
N: d.N.clone(),
s: d.s.clone(),
t: d.t.clone(),
multiexp: None,
crt: None,
})
.collect::<Vec<_>>();
party_auxes[usize::from(i)].crt = crt;
let mut aux = DirtyAuxInfo {
p,
q,
parties: party_auxes,
security_level: std::marker::PhantomData,
};
if compute_multiexp_table {
tracer.stage("Precompute multiexp tables");
aux.precompute_multiexp_tables()
.map_err(Bug::BuildMultiexpTables)?;
}
let aux = aux
.validate()
.map_err(|err| Bug::InvalidShareGenerated(err.into_error()))?;
tracer.protocol_ends();
Ok(aux)
}