use digest::Digest;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::cryptographic_primitives::hashing::DigestExt;
use crate::cryptographic_primitives::proofs::ProofError;
use crate::cryptographic_primitives::secret_sharing::Polynomial;
use crate::elliptic::curves::{Curve, Point, Scalar};
use crate::HashChoice;
/// Secret witness for an LDEI proof: the polynomial `w` whose evaluations appear
/// "in the exponent" of the statement (`x[i] = g[i] * w(alpha[i])`).
#[derive(Clone, Debug, Serialize, Deserialize)]
// Empty bound: don't make serde require `E: Serialize/Deserialize`; only the field types matter.
#[serde(bound = "")]
pub struct LdeiWitness<E: Curve> {
    /// Secret polynomial; its degree must not exceed the statement's `d`.
    pub w: Polynomial<E>,
}
/// Public statement of an LDEI proof: for every index `i`,
/// `x[i] = g[i] * w(alpha[i])` for some polynomial `w` with `deg(w) <= d`.
///
/// Fields are public, so a statement can be built without going through
/// [`LdeiStatement::new`]; in that case none of its consistency checks have run.
#[derive(Clone, Debug, Serialize, Deserialize)]
// Empty bound: don't make serde require `E: Serialize/Deserialize`; only the field types matter.
#[serde(bound = "")]
pub struct LdeiStatement<E: Curve> {
    /// Evaluation points; expected to be pairwise distinct and as long as `g`.
    pub alpha: Vec<Scalar<E>>,
    /// Bases, one per evaluation point.
    pub g: Vec<Point<E>>,
    /// Public values `x[i] = g[i] * w(alpha[i])`.
    pub x: Vec<Point<E>>,
    /// Upper bound on the degree of the witness polynomial.
    pub d: u16,
}
impl<E: Curve> LdeiStatement<E> {
    /// Builds the public statement for `witness`, computing `x[i] = g[i] * w(alpha[i])`.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    /// * [`InvalidLdeiStatement::AlphaLengthDoesntMatchG`] if `alpha` and `g` differ in length;
    /// * [`InvalidLdeiStatement::PolynomialDegreeMoreThanD`] if `deg(w) > d`;
    /// * [`InvalidLdeiStatement::AlphaNotPairwiseDistinct`] if two `alpha`s are equal.
    pub fn new(
        witness: &LdeiWitness<E>,
        alpha: Vec<Scalar<E>>,
        g: Vec<Point<E>>,
        d: u16,
    ) -> Result<Self, InvalidLdeiStatement> {
        let w = &witness.w;
        if alpha.len() != g.len() {
            Err(InvalidLdeiStatement::AlphaLengthDoesntMatchG)
        } else if w.degree() > d.into() {
            Err(InvalidLdeiStatement::PolynomialDegreeMoreThanD)
        } else if !ensure_list_is_pairwise_distinct(&alpha) {
            Err(InvalidLdeiStatement::AlphaNotPairwiseDistinct)
        } else {
            // Lengths are equal here, so the zip covers every index.
            let x: Vec<Point<E>> = alpha
                .iter()
                .zip(&g)
                .map(|(alpha_i, g_i)| g_i * w.evaluate(alpha_i))
                .collect();
            Ok(Self { alpha, g, x, d })
        }
    }
}
/// Non-interactive LDEI proof; the challenge `e` is derived by hashing with `H`.
#[derive(Clone, Debug, Serialize, Deserialize)]
// Empty bound: don't make serde require `E`/`H` to be serializable; only the field types matter.
#[serde(bound = "")]
pub struct LdeiProof<E: Curve, H: Digest + Clone> {
    /// Prover commitments `a[i] = g[i] * u(alpha[i])` for a random polynomial `u`.
    pub a: Vec<Point<E>>,
    /// Challenge scalar, computed as `H(g, x, a)`.
    pub e: Scalar<E>,
    /// Response polynomial `z = u - e * w`.
    pub z: Polynomial<E>,
    /// Records the hash function type `H`. Not serialized: serde's `skip` fills it with its
    /// `Default` value on deserialization.
    #[serde(skip)]
    pub hash_choice: HashChoice<H>,
}
impl<E: Curve, H: Digest + Clone> LdeiProof<E, H> {
    /// Proves that `statement.x[i] = statement.g[i] * w(statement.alpha[i])` for every `i`,
    /// where `w = witness.w` has degree at most `statement.d`.
    ///
    /// The statement is re-validated against `witness` first, so a proof is only produced
    /// for a statement the witness actually satisfies.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidLdeiStatement`] if `alpha`/`g` lengths differ, `deg(w) > d`, the
    /// `alpha`s are not pairwise distinct, or `statement.x` doesn't match `witness`.
    // Single-letter names (`a`, `e`, `u`, `z`, `g`, `x`) mirror the protocol's notation.
    #[allow(clippy::many_single_char_names)]
    pub fn prove(
        witness: &LdeiWitness<E>,
        statement: &LdeiStatement<E>,
    ) -> Result<LdeiProof<E, H>, InvalidLdeiStatement> {
        if statement.alpha.len() != statement.g.len() {
            return Err(InvalidLdeiStatement::AlphaLengthDoesntMatchG);
        }
        if witness.w.degree() > statement.d.into() {
            return Err(InvalidLdeiStatement::PolynomialDegreeMoreThanD);
        }
        if !ensure_list_is_pairwise_distinct(&statement.alpha) {
            return Err(InvalidLdeiStatement::AlphaNotPairwiseDistinct);
        }
        let x_expected: Vec<Point<E>> = statement
            .g
            .iter()
            .zip(&statement.alpha)
            .map(|(g, a)| g * witness.w.evaluate(a))
            .collect();
        if statement.x != x_expected {
            return Err(InvalidLdeiStatement::ListOfXDoesntMatchExpectedValue);
        }

        // Commit to a fresh random polynomial `u` of degree `d`, evaluated in the exponent.
        let u = Polynomial::<E>::sample_exact(statement.d);
        let a: Vec<Point<E>> = statement
            .g
            .iter()
            .zip(&statement.alpha)
            .map(|(g, a)| g * u.evaluate(a))
            .collect();

        // Fiat–Shamir challenge over (g, x, a).
        // NOTE(review): `alpha` and `d` are not bound into the transcript. If statements can be
        // chosen by the prover after seeing `e`, consider hashing them too. This is left unchanged
        // because altering the transcript would break verification of previously issued proofs.
        let e = H::new()
            .chain_points(&statement.g)
            .chain_points(&statement.x)
            .chain_points(&a)
            .result_scalar();

        // Response z = u - e*w; deg(z) <= d because deg(u), deg(w) <= d.
        let z = &u - &(&witness.w * &e);
        Ok(LdeiProof {
            a,
            e,
            z,
            hash_choice: HashChoice::new(),
        })
    }

    /// Verifies the proof against `statement`.
    ///
    /// # Errors
    ///
    /// Returns [`ProofError`] if:
    /// * the statement is malformed (`alpha`, `g`, `x` and `a` don't all have the same length,
    ///   or the `alpha`s are not pairwise distinct);
    /// * the stored challenge doesn't match `H(g, x, a)`;
    /// * `deg(z) > d`;
    /// * `a[i] != g[i] * z(alpha[i]) + x[i] * e` for some `i`.
    pub fn verify(&self, statement: &LdeiStatement<E>) -> Result<(), ProofError> {
        // `zip` below stops at the shortest list. Without this check a statement with mismatched
        // lengths would have only a prefix of its relation verified.
        let n = statement.g.len();
        if statement.alpha.len() != n || statement.x.len() != n || self.a.len() != n {
            return Err(ProofError);
        }
        // Hold the verifier to the same invariant `prove` enforces (O(n^2) comparisons).
        if !ensure_list_is_pairwise_distinct(&statement.alpha) {
            return Err(ProofError);
        }

        let e = H::new()
            .chain_points(&statement.g)
            .chain_points(&statement.x)
            .chain_points(&self.a)
            .result_scalar();
        if e != self.e {
            return Err(ProofError);
        }
        if self.z.degree() > statement.d.into() {
            return Err(ProofError);
        }

        // Honest proof: g*z(alpha) + x*e = g*(u - e*w)(alpha) + g*w(alpha)*e = g*u(alpha) = a.
        let expected_a: Vec<_> = statement
            .g
            .iter()
            .zip(&statement.alpha)
            .zip(&statement.x)
            .map(|((g, a), x)| g * self.z.evaluate(a) + x * &e)
            .collect();

        if self.a == expected_a {
            Ok(())
        } else {
            Err(ProofError)
        }
    }
}
/// Reasons a statement is rejected by [`LdeiStatement::new`] or [`LdeiProof::prove`].
#[derive(Debug, Clone, Error)]
pub enum InvalidLdeiStatement {
    /// Two evaluation points in `alpha` are equal.
    #[error("`alpha`s are not pairwise distinct")]
    AlphaNotPairwiseDistinct,
    /// `alpha` and `g` have different lengths.
    #[error("alpha.len() != g.len()")]
    AlphaLengthDoesntMatchG,
    /// The witness polynomial's degree exceeds the statement's bound `d`.
    #[error("deg(w) > d")]
    PolynomialDegreeMoreThanD,
    /// `statement.x` is not `g[i] * w(alpha[i])` for the given witness (only checked by `prove`).
    #[error("`statement.x` doesn't match expected value")]
    ListOfXDoesntMatchExpectedValue,
}
/// Returns `true` when no two elements of `list` compare equal (vacuously true for
/// empty and single-element lists).
///
/// Only `PartialEq` is required, so the elements need not be `Hash` or `Ord`. The check is
/// therefore a quadratic scan, comparing each unordered pair once.
fn ensure_list_is_pairwise_distinct<S: PartialEq>(list: &[S]) -> bool {
    let has_duplicate = list
        .iter()
        .enumerate()
        .any(|(idx, head)| list[idx + 1..].iter().any(|later| head == later));
    !has_duplicate
}
#[cfg(test)]
mod tests {
    use std::iter;

    use crate::elliptic::curves::{Curve, Scalar};
    use crate::test_for_all_curves_and_hashes;

    use super::*;

    /// Ten distinct evaluation points `1..=10`.
    fn sample_alpha<E: Curve>() -> Vec<Scalar<E>> {
        (1..=10).map(Scalar::from).collect()
    }

    /// Ten bases, each a random multiple of the curve generator.
    fn sample_g<E: Curve>() -> Vec<Point<E>> {
        iter::repeat_with(Scalar::random)
            .map(|x| Point::generator() * x)
            .take(10)
            .collect()
    }

    test_for_all_curves_and_hashes!(correctly_proofs);
    fn correctly_proofs<E: Curve, H: Digest + Clone>() {
        let d = 5;
        let witness = LdeiWitness {
            w: Polynomial::<E>::sample_exact(d),
        };
        let statement = LdeiStatement::new(&witness, sample_alpha(), sample_g(), d).unwrap();
        let proof = LdeiProof::<_, H>::prove(&witness, &statement).expect("failed to prove");
        proof.verify(&statement).expect("failed to validate proof");
    }

    test_for_all_curves_and_hashes!(proof_does_not_verify_against_other_statement);
    fn proof_does_not_verify_against_other_statement<E: Curve, H: Digest + Clone>() {
        let d = 5;
        let alpha = sample_alpha::<E>();
        let g = sample_g::<E>();
        let witness = LdeiWitness {
            w: Polynomial::<E>::sample_exact(d),
        };
        let other_witness = LdeiWitness {
            w: Polynomial::<E>::sample_exact(d),
        };
        let statement = LdeiStatement::new(&witness, alpha.clone(), g.clone(), d).unwrap();
        let other_statement = LdeiStatement::new(&other_witness, alpha, g, d).unwrap();
        let proof = LdeiProof::<_, H>::prove(&witness, &statement).expect("failed to prove");
        assert!(proof.verify(&other_statement).is_err());
    }

    test_for_all_curves_and_hashes!(statement_rejects_invalid_inputs);
    fn statement_rejects_invalid_inputs<E: Curve, H: Digest + Clone>() {
        let d = 5;
        let witness = LdeiWitness {
            w: Polynomial::<E>::sample_exact(d),
        };

        let too_high_degree = LdeiWitness {
            w: Polynomial::<E>::sample_exact(d + 1),
        };
        assert!(matches!(
            LdeiStatement::new(&too_high_degree, sample_alpha(), sample_g(), d),
            Err(InvalidLdeiStatement::PolynomialDegreeMoreThanD)
        ));

        let mut short_alpha = sample_alpha::<E>();
        short_alpha.pop();
        assert!(matches!(
            LdeiStatement::new(&witness, short_alpha, sample_g(), d),
            Err(InvalidLdeiStatement::AlphaLengthDoesntMatchG)
        ));

        let mut repeated_alpha = sample_alpha::<E>();
        repeated_alpha[9] = repeated_alpha[0].clone();
        assert!(matches!(
            LdeiStatement::new(&witness, repeated_alpha, sample_g(), d),
            Err(InvalidLdeiStatement::AlphaNotPairwiseDistinct)
        ));
    }
}