use std::{fmt::Display, ops::ControlFlow};
use thiserror::Error;
use crate::{
ConstantsTrait, HashError, HashableMessage, Integer, IntegerOperationError, OperationsTrait,
RecursiveHashTrait,
elgamal::{Ciphertext, ElgamalError},
integer::ModExponentiateError,
mix_net::{
MixNetResultTrait, MixnetError, MixnetErrorRepr,
commitments::{CommitmentError, get_commitment},
matrix::Matrix,
},
};
use super::ArgumentContext;
/// Public statement of the multi-exponentiation argument.
///
/// Built with [`MultiExponentiationStatement::new`], which checks that the ciphertext matrix is
/// not malformed, that its number of rows equals the number of commitments in `cs_upper_a`, and
/// that every ciphertext of the matrix has the same size `l` as `upper_c`.
#[derive(Debug, Clone)]
pub struct MultiExponentiationStatement<'a> {
    /// Ciphertext matrix; its row count is `m` and its column count is `n`.
    ciphertext_matrix: &'a Matrix<Ciphertext>,
    /// Ciphertext `C`; its size defines `l`, and verification compares it with `E[m]`.
    upper_c: &'a Ciphertext,
    /// Commitments `c_A`, one per matrix row; their count defines `m`.
    cs_upper_a: &'a [Integer],
}
/// Prover's multi-exponentiation argument.
///
/// Normally built with [`MultiExponentiationArgument::new`], which validates the sizes. The
/// fields are public, so a value built directly from them bypasses that validation.
#[derive(Debug, Clone)]
pub struct MultiExponentiationArgument<'a> {
    /// Commitment `c_A_0`, multiplied into the product of the statement's `c_A` commitments.
    pub c_upper_a_0: &'a Integer,
    /// Commitments `c_B`; `m()` is half their count.
    pub cs_upper_b: &'a [Integer],
    /// Ciphertexts `E`, expected to have as many entries as `cs_upper_b`.
    pub upper_es: &'a [Ciphertext],
    /// Exponent vector `a`; its length defines `n`.
    pub a_vec: &'a [Integer],
    /// Value passed together with `a_vec` to `get_commitment` (the commitment randomness).
    pub r: &'a Integer,
    /// Value `b`; the verifier recomputes its commitment and encrypts `g^b mod p`.
    pub b: &'a Integer,
    /// Value passed together with `b` to `get_commitment` (the commitment randomness).
    pub s: &'a Integer,
    /// Randomness used when encrypting `g^b mod p` during verification.
    pub tau: &'a Integer,
}
/// Statement/argument pair consumed by [`verify_multi_exponentiation_argument`].
///
/// The fields are private, so outside this module the pair is built through
/// [`MultiExponentiationArgumentVerifyInput::new`]. That constructor checks that both sides agree
/// on `m`, `n` and `l`, and that `m` and `n` are non-zero.
#[derive(Debug, Clone)]
pub struct MultiExponentiationArgumentVerifyInput<'a, 'b> {
    statement: &'a MultiExponentiationStatement<'a>,
    argument: &'b MultiExponentiationArgument<'b>,
}
/// Outcome of each check performed by [`verify_multi_exponentiation_argument`].
///
/// The argument is accepted only if every flag is `true` (see the `MixNetResultTrait` impl).
#[derive(Debug, Eq, PartialEq)]
pub struct MultiExponentiationArgumentResult {
    /// `c_B[m] == 1`.
    pub verif_upper_c_b_m: bool,
    /// `E[m] == C`.
    pub verif_upper_e_m: bool,
    /// The product of `c_A_0` and the `c_A` commitments matches the commitment to `a`.
    pub verif_upper_a: bool,
    /// The product of the `c_B` commitments matches the commitment to `b`.
    pub verif_upper_b: bool,
    /// The product of the `E` ciphertexts matches the encryption of `g^b` times the matrix
    /// product.
    pub verif_upper_e_upper_c: bool,
}
#[derive(Error, Debug)]
pub enum MultiExponentiationArgumentError {
#[error("Ciphertext matrix is malformed")]
CyphertextMatrixMalformed,
#[error("Ciphertext not same length in {0}")]
CyphertextNotSameL(String),
#[error("Commitment vectors c_b is not equal to ciphertext vector")]
CommitmentVectorNotSameLen,
#[error("{0} is not consistent")]
ValueNotConsistent(String),
#[error("{0} is too small")]
SizeTooSmall(String),
#[error("Error for x")]
X { source: HashError },
#[error("Error calculating the pwoers of x")]
XPowers { source: ModExponentiateError },
#[error("Error calculating the product of A")]
ProdA { source: IntegerOperationError },
#[error("error calculation Commitment A")]
CommitmentA { source: CommitmentError },
#[error("Error calculating the product of B")]
ProdB { source: IntegerOperationError },
#[error("error calculation Commitment B")]
CommitmentB { source: CommitmentError },
#[error("error calculation product E")]
ProdE { source: ElgamalError },
#[error("error calculation g^b mod p")]
GExpBModP { source: ModExponentiateError },
#[error("error encrypting g^b mod p")]
EncryptionGExpModP { source: ElgamalError },
}
/// Verifies a multi-exponentiation argument against its statement.
///
/// Steps (0-based indices; `m` and `l` are taken from the statement):
/// 1. Derive the challenge `x` with [`get_x`] and compute `x^0, ..., x^(2m-1)` modulo `q`.
/// 2. `verif_upper_c_b_m`: `c_B[m] == 1`.
/// 3. `verif_upper_e_m`: `E[m] == C`.
/// 4. `verif_upper_a`: `c_A_0` times the multi-exponentiation of `c_A` with `x^1, x^2, ...`
///    (mod `p`) equals the commitment to `a` with `r`.
/// 5. `verif_upper_b`: the multi-exponentiation of `c_B` with `x^0, x^1, ...` (mod `p`) equals
///    the commitment to the one-element vector `(b)` with `s`.
/// 6. `verif_upper_e_upper_c`: `E[0] * prod_{k>=1} E[k]^(x^k)` equals the ciphertext of
///    `(g^b mod p, ..., g^b mod p)` (`l` copies, randomness `tau`) multiplied by
///    `prod_i C_i^(x^(m-1-i) * a mod q)`, where `C_i` is row `i` of the ciphertext matrix.
///
/// A failed check is reported through the flags of the returned
/// [`MultiExponentiationArgumentResult`], not as an `Err`.
///
/// # Errors
/// Returns an error when one of the underlying computations fails: hashing, exponentiation,
/// commitment or ciphertext operation.
///
/// # Panics
/// Indexes `c_B[m]`, `E[m]` and `E[0]`. This relies on `input` coming from
/// [`MultiExponentiationArgumentVerifyInput::new`], which checks that `m` agrees and is non-zero.
/// It also relies on an argument whose `c_B` and `E` have matching lengths, which
/// `MultiExponentiationArgument::new` enforces.
pub fn verify_multi_exponentiation_argument(
    context: &ArgumentContext,
    input: &MultiExponentiationArgumentVerifyInput,
) -> Result<MultiExponentiationArgumentResult, MultiExponentiationArgumentError> {
    let statement = input.statement;
    let argument = input.argument;
    let p = context.ep.p();
    let q = context.ep.q();
    let g = context.ep.g();
    let m = statement.m();
    let l = statement.l();
    // Challenge x, bound to the context, the statement and the prover's commitments.
    let x = get_x(context, statement, argument)
        .map_err(|e| MultiExponentiationArgumentError::X { source: e })?;
    // x^0, x^1, ..., x^(2m-1), each reduced modulo q.
    let x_powers = (0..2 * m)
        .map(|i| x.mod_exponentiate(&Integer::from(i), q))
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| MultiExponentiationArgumentError::XPowers { source: e })?;
    let verif_upper_c_b_m = &argument.cs_upper_b[m] == Integer::one();
    let verif_upper_e_m = &argument.upper_es[m] == statement.upper_c;
    // c_A_0 * prod_i c_A[i]^(x^(i+1)) mod p: skip(1) makes the exponents start at x^1.
    let prod_upper_c_a = argument.c_upper_a_0.mod_multiply(
        &Integer::mod_multi_exponentiate_iter(
            &mut statement.cs_upper_a.iter(),
            &mut x_powers.iter().skip(1),
            p,
        )
        .map_err(|e| MultiExponentiationArgumentError::ProdA { source: e })?,
        p,
    );
    let comm_upper_a = get_commitment(context.ep, argument.a_vec, argument.r, context.ck)
        .map_err(|e| MultiExponentiationArgumentError::CommitmentA { source: e })?;
    let verif_upper_a = prod_upper_c_a == comm_upper_a;
    // prod_k c_B[k]^(x^k) mod p, starting at x^0.
    let prod_upper_c_b = Integer::mod_multi_exponentiate_iter(
        &mut argument.cs_upper_b.iter(),
        &mut x_powers.iter(),
        p,
    )
    .map_err(|e| MultiExponentiationArgumentError::ProdB { source: e })?;
    // Commitment to the one-element vector (b).
    let comm_upper_b = get_commitment(
        context.ep,
        std::slice::from_ref(argument.b),
        argument.s,
        context.ck,
    )
    .map_err(|e| MultiExponentiationArgumentError::CommitmentB { source: e })?;
    let verif_upper_b = prod_upper_c_b == comm_upper_b;
    // E[0] * prod_{k>=1} E[k]^(x^k): the pair (E[0], x^0) is skipped and E[0] seeds the fold.
    // The fold stops at the first failed exponentiation.
    let prod_upper_e = match argument
        .upper_es
        .iter()
        .zip(x_powers.iter())
        .skip(1)
        .map(|(e_k, x_k)| e_k.get_ciphertext_exponentiation(x_k, context.ep))
        .try_fold(argument.upper_es[0].clone(), |acc, e_res| match e_res {
            Ok(e) => ControlFlow::Continue(acc.get_ciphertext_product(&e, context.ep)),
            Err(e) => ControlFlow::Break(e),
        }) {
        ControlFlow::Continue(v) => Ok(v),
        ControlFlow::Break(e) => Err(MultiExponentiationArgumentError::ProdE { source: e }),
    }?;
    // Ciphertext of (g^b mod p) repeated l times, built with randomness tau and the public keys.
    let encrypted_upper_g_b = Ciphertext::get_ciphertext(
        context.ep,
        vec![
            g.mod_exponentiate(argument.b, p)
                .map_err(|e| MultiExponentiationArgumentError::GExpBModP { source: e })?;
            l
        ]
        .as_slice(),
        argument.tau,
        context.pks,
    )
    .map_err(|e| MultiExponentiationArgumentError::EncryptionGExpModP { source: e })?;
    // prod_i C_i^(x^(m-1-i) * a mod q) over the matrix rows C_i. take(m).rev() yields
    // x^(m-1), ..., x^0, so row 0 is paired with the highest power.
    let prod_c = match statement
        .ciphertext_matrix
        .rows_iter()
        .zip(x_powers.iter().take(m).rev())
        .map(|(c_i, x_m_minus_i_minus_1)| {
            Ciphertext::get_ciphertext_vector_exponentiation(
                c_i.to_vec().as_slice(),
                x_m_minus_i_minus_1
                    .mod_scalar_multiply(argument.a_vec, q)
                    .as_slice(),
                context.ep,
            )
        })
        .try_fold(
            Ciphertext::neutral_for_mod_multiply(l),
            |acc, c_res| match c_res {
                Ok(c) => ControlFlow::Continue(acc.get_ciphertext_product(&c, context.ep)),
                Err(e) => ControlFlow::Break(e),
            },
        ) {
        ControlFlow::Continue(c) => Ok(c),
        // NOTE(review): failures of this matrix product are also reported as ProdE.
        ControlFlow::Break(e) => Err(MultiExponentiationArgumentError::ProdE { source: e }),
    }?;
    let verif_upper_e_upper_c =
        prod_upper_e == encrypted_upper_g_b.get_ciphertext_product(&prod_c, context.ep);
    Ok(MultiExponentiationArgumentResult {
        verif_upper_c_b_m,
        verif_upper_e_m,
        verif_upper_a,
        verif_upper_b,
        verif_upper_e_upper_c,
    })
}
/// Computes the challenge `x` of the multi-exponentiation argument.
///
/// The recursive hash covers, in this order: `p`, `q`, the public keys, the commitment key,
/// the ciphertext matrix, `C`, `c_A`, `c_A_0`, `c_B` and `E`. The digest is then converted to
/// an integer. Keep the element order in sync with the specification.
///
/// # Errors
/// Propagates the [`HashError`] raised by the recursive hash.
pub fn get_x(
    context: &ArgumentContext,
    statement: &MultiExponentiationStatement,
    argument: &MultiExponentiationArgument,
) -> Result<Integer, HashError> {
    let hashed_elements = vec![
        // Public context.
        HashableMessage::from(context.ep.p()),
        HashableMessage::from(context.ep.q()),
        HashableMessage::from(context.pks),
        HashableMessage::from(context.ck),
        // Statement.
        HashableMessage::from(statement.ciphertext_matrix),
        HashableMessage::from(statement.upper_c),
        HashableMessage::from(statement.cs_upper_a),
        // Prover's commitments.
        HashableMessage::from(argument.c_upper_a_0),
        HashableMessage::from(argument.cs_upper_b),
        HashableMessage::from(argument.upper_es),
    ];
    let challenge = HashableMessage::from(hashed_elements)
        .recursive_hash()?
        .into_integer();
    Ok(challenge)
}
impl<'a> MultiExponentiationStatement<'a> {
    /// Builds a statement after validating its components.
    ///
    /// # Errors
    /// - `CyphertextMatrixMalformed` if the ciphertext matrix is malformed.
    /// - `CommitmentVectorNotSameLen` if the number of matrix rows differs from the number of
    ///   commitments in `cs_upper_a`.
    /// - `CyphertextNotSameL` if a ciphertext of the matrix does not have the same size `l` as
    ///   `upper_c`.
    pub fn new(
        ciphertext_matrix: &'a Matrix<Ciphertext>,
        upper_c: &'a Ciphertext,
        cs_upper_a: &'a [Integer],
    ) -> Result<Self, MultiExponentiationArgumentError> {
        if ciphertext_matrix.is_malformed() {
            return Err(MultiExponentiationArgumentError::CyphertextMatrixMalformed);
        }
        if cs_upper_a.len() != ciphertext_matrix.nb_rows() {
            return Err(MultiExponentiationArgumentError::CommitmentVectorNotSameLen);
        }
        let expected_l = upper_c.l();
        // Scan column by column; stops at the first ciphertext of a different size.
        let all_same_size = (0..ciphertext_matrix.nb_columns()).all(|column_index| {
            ciphertext_matrix
                .column(column_index)
                .iter()
                .all(|ciphertext| ciphertext.l() == expected_l)
        });
        if !all_same_size {
            return Err(MultiExponentiationArgumentError::CyphertextNotSameL(
                "MultiExponentiationStatement (C to ciphertext_matrix)".to_string(),
            ));
        }
        Ok(Self {
            ciphertext_matrix,
            upper_c,
            cs_upper_a,
        })
    }

    /// Number `m` of rows, i.e. of commitments in `c_A`.
    pub fn m(&self) -> usize {
        self.cs_upper_a.len()
    }

    /// Number `n` of columns of the ciphertext matrix.
    pub fn n(&self) -> usize {
        self.ciphertext_matrix.nb_columns()
    }

    /// Size `l` of the ciphertexts, taken from `C`.
    pub fn l(&self) -> usize {
        self.upper_c.l()
    }
}
// `new` and `new_impl` take one parameter per field of the argument, so the parameter count is
// inherent to the type.
#[allow(clippy::too_many_arguments)]
impl<'a> MultiExponentiationArgument<'a> {
    /// Builds an argument after validating the sizes of its components.
    ///
    /// # Errors
    /// Returns a [`MixnetError`] wrapping the [`MultiExponentiationArgumentError`] reported by
    /// the validation: length mismatch between `c_B` and `E`, an empty or odd-length `E`, or
    /// ciphertexts of different sizes in `E`.
    pub fn new(
        c_upper_a_0: &'a Integer,
        cs_upper_b: &'a [Integer],
        upper_es: &'a [Ciphertext],
        a_vec: &'a [Integer],
        r: &'a Integer,
        b: &'a Integer,
        s: &'a Integer,
        tau: &'a Integer,
    ) -> Result<Self, MixnetError> {
        Self::new_impl(c_upper_a_0, cs_upper_b, upper_es, a_vec, r, b, s, tau)
            .map_err(MixnetErrorRepr::from)
            .map_err(|e| MixnetError {
                source: Box::new(e),
            })
    }

    /// Validation behind [`Self::new`].
    ///
    /// # Errors
    /// - `CommitmentVectorNotSameLen` if `cs_upper_b` and `upper_es` differ in length.
    /// - `SizeTooSmall("m")` if `upper_es` is empty.
    /// - `ValueNotConsistent("m")` if their common length is odd, i.e. not of the form `2m`.
    /// - `CyphertextNotSameL` if the ciphertexts of `upper_es` do not all have the same size.
    fn new_impl(
        c_upper_a_0: &'a Integer,
        cs_upper_b: &'a [Integer],
        upper_es: &'a [Ciphertext],
        a_vec: &'a [Integer],
        r: &'a Integer,
        b: &'a Integer,
        s: &'a Integer,
        tau: &'a Integer,
    ) -> Result<Self, MultiExponentiationArgumentError> {
        if cs_upper_b.len() != upper_es.len() {
            return Err(MultiExponentiationArgumentError::CommitmentVectorNotSameLen);
        }
        // An empty `E` previously panicked on `upper_es[0]`; report it as a zero `m` instead.
        let Some(first) = upper_es.first() else {
            return Err(MultiExponentiationArgumentError::SizeTooSmall(
                "m".to_string(),
            ));
        };
        // `m()` is `len / 2` and verification only computes the powers x^0..x^(2m-1), so with an
        // odd length the last entry of `c_B` and `E` would get no power of `x`; reject it.
        if upper_es.len() % 2 != 0 {
            return Err(MultiExponentiationArgumentError::ValueNotConsistent(
                "m".to_string(),
            ));
        }
        let l = first.l();
        if !upper_es.iter().all(|e| e.l() == l) {
            return Err(MultiExponentiationArgumentError::CyphertextNotSameL(
                "MultiExponentiationArgument (in E)".to_string(),
            ));
        }
        Ok(Self {
            c_upper_a_0,
            cs_upper_b,
            upper_es,
            a_vec,
            r,
            b,
            s,
            tau,
        })
    }

    /// Dimension `m`: half the number of commitments in `c_B`.
    pub fn m(&self) -> usize {
        self.cs_upper_b.len() / 2
    }

    /// Dimension `n`: length of the exponent vector `a`.
    pub fn n(&self) -> usize {
        self.a_vec.len()
    }

    /// Size `l` of the ciphertexts in `E`, read from the first one.
    ///
    /// Returns `0` when `E` is empty, which is only possible if the argument was built
    /// directly from its public fields rather than through [`Self::new`].
    pub fn l(&self) -> usize {
        self.upper_es.first().map_or(0, |e| e.l())
    }
}
impl<'a, 'b> MultiExponentiationArgumentVerifyInput<'a, 'b> {
    /// Pairs a statement with an argument after checking that they fit together.
    ///
    /// # Errors
    /// - `ValueNotConsistent` naming the first of `m`, `n`, `l` (checked in that order) on which
    ///   the statement and the argument disagree.
    /// - `SizeTooSmall` naming `m` or `n` when that dimension is zero.
    pub fn new(
        statement: &'a MultiExponentiationStatement<'a>,
        argument: &'b MultiExponentiationArgument<'b>,
    ) -> Result<Self, MultiExponentiationArgumentError> {
        let not_consistent =
            |name: &str| MultiExponentiationArgumentError::ValueNotConsistent(name.to_string());
        let too_small =
            |name: &str| MultiExponentiationArgumentError::SizeTooSmall(name.to_string());

        // Dimensions must agree. Each query runs only once the previous ones have matched.
        if statement.m() != argument.m() {
            return Err(not_consistent("m"));
        }
        if statement.n() != argument.n() {
            return Err(not_consistent("n"));
        }
        if statement.l() != argument.l() {
            return Err(not_consistent("l"));
        }

        // Degenerate (empty) dimensions are rejected.
        if statement.m() == 0 {
            return Err(too_small("m"));
        }
        if statement.n() == 0 {
            return Err(too_small("n"));
        }

        Ok(Self {
            statement,
            argument,
        })
    }
}
impl MixNetResultTrait for MultiExponentiationArgumentResult {
    /// The argument is accepted only when all five verification flags are `true`.
    fn is_ok(&self) -> bool {
        [
            self.verif_upper_a,
            self.verif_upper_b,
            self.verif_upper_c_b_m,
            self.verif_upper_e_m,
            self.verif_upper_e_upper_c,
        ]
        .iter()
        .all(|&passed| passed)
    }
}
impl Display for MultiExponentiationArgumentResult {
    /// Writes `verification ok` when every check passed; otherwise lists the value of each flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if !self.is_ok() {
            return write!(
                f,
                "verifCbm: {}, verifEm: {}, verifA: {}, verifB: {}, verifEC: {}",
                self.verif_upper_c_b_m,
                self.verif_upper_e_m,
                self.verif_upper_a,
                self.verif_upper_b,
                self.verif_upper_e_upper_c
            );
        }
        f.write_str("verification ok")
    }
}
#[cfg(test)]
pub mod test {
    use super::*;
    use crate::mix_net::arguments::test_json_data::json_to_context_values;
    use crate::test_json_data::{
        get_test_cases_from_json_file, json_64_value_to_integer,
        json_array_64_value_to_array_integer, json_values_to_ciphertext,
    };
    use serde_json::Value;

    /// Reads a JSON array of ciphertexts; panics if `value` is not an array.
    pub fn get_ciphertexts(value: &Value) -> Vec<Ciphertext> {
        let entries = value.as_array().unwrap();
        entries.iter().map(json_values_to_ciphertext).collect()
    }

    /// Reads a JSON array of ciphertext arrays as a matrix, one inner array per row.
    /// Panics if the JSON is not an array or if `Matrix::from_rows` fails.
    pub fn get_ciphertext_matrix(value: &Value) -> Matrix<Ciphertext> {
        let rows = value
            .as_array()
            .unwrap()
            .iter()
            .map(get_ciphertexts)
            .collect::<Vec<Vec<Ciphertext>>>();
        Matrix::from_rows(&rows).unwrap()
    }

    /// Owned data backing a [`MultiExponentiationStatement`]: matrix, `C`, `c_A`.
    pub struct MEStatementValues(pub Matrix<Ciphertext>, pub Ciphertext, pub Vec<Integer>);

    /// Owned data backing a [`MultiExponentiationArgument`], in constructor order:
    /// `c_A_0`, `c_B`, `E`, `a`, `r`, `b`, `s`, `tau`.
    pub struct MEArgumentValues(
        pub Integer,
        pub Vec<Integer>,
        pub Vec<Ciphertext>,
        pub Vec<Integer>,
        pub Integer,
        pub Integer,
        pub Integer,
        pub Integer,
    );

    fn get_statement_values(statement: &Value) -> MEStatementValues {
        let matrix = get_ciphertext_matrix(&statement["ciphertexts"]);
        let product = json_values_to_ciphertext(&statement["ciphertext_product"]);
        let c_a = json_array_64_value_to_array_integer(&statement["c_a"]);
        MEStatementValues(matrix, product, c_a)
    }

    fn get_statement(values: &MEStatementValues) -> MultiExponentiationStatement<'_> {
        let MEStatementValues(matrix, product, c_a) = values;
        MultiExponentiationStatement::new(matrix, product, c_a).unwrap()
    }

    /// Reads the argument fields from JSON, in the order of [`MEArgumentValues`].
    pub fn get_argument_values(argument: &Value) -> MEArgumentValues {
        let integer = |key: &str| json_64_value_to_integer(&argument[key]);
        let integers = |key: &str| json_array_64_value_to_array_integer(&argument[key]);
        MEArgumentValues(
            integer("c_a_0"),
            integers("c_b"),
            get_ciphertexts(&argument["e"]),
            integers("a"),
            integer("r"),
            integer("b"),
            integer("s"),
            integer("tau"),
        )
    }

    /// Builds an argument borrowing from `values`; panics if validation fails.
    pub fn get_argument(values: &MEArgumentValues) -> MultiExponentiationArgument<'_> {
        let MEArgumentValues(c_a_0, c_b, e, a, r, b, s, tau) = values;
        MultiExponentiationArgument::new(c_a_0, c_b, e, a, r, b, s, tau).unwrap()
    }

    #[test]
    fn test_verify() {
        let test_cases = get_test_cases_from_json_file("mixnet", "verify-multiexp-argument.json");
        for tc in test_cases.iter() {
            let context_values = json_to_context_values(&tc["context"]);
            let context = ArgumentContext::from(&context_values);
            let statement_values = get_statement_values(&tc["input"]["statement"]);
            let statement = get_statement(&statement_values);
            let argument_values = get_argument_values(&tc["input"]["argument"]);
            let argument = get_argument(&argument_values);
            let input = MultiExponentiationArgumentVerifyInput::new(&statement, &argument).unwrap();

            // The challenge must match the reference value.
            let x = get_x(&context, &statement, &argument)
                .unwrap_or_else(|e| panic!("Error unwraping x {}: {}", tc["description"], e));
            assert_eq!(
                x,
                json_64_value_to_integer(&tc["output"]["x"]),
                "Verifying x: {}",
                tc["description"]
            );

            // The full verification must succeed and accept the argument.
            let result = verify_multi_exponentiation_argument(&context, &input)
                .unwrap_or_else(|e| panic!("Error unwraping {}: {}", tc["description"], e));
            assert!(
                result.is_ok(),
                "Verification for {} not ok: {}",
                tc["description"],
                result
            );
        }
    }
}