use crate::{
common::*,
curve_arithmetic::{Curve, Field},
pedersen_commitment::Value as PedersenValue,
};
use anyhow::bail;
use rand::*;
use serde::de::Error;
use serde_json::{json, Value};
use std::convert::TryFrom;
/// Threshold for secret sharing: the number of shares passed to [`share`] as the
/// revealing threshold (the sharing polynomial has degree `threshold - 1`).
///
/// Invariant: the wrapped value is at least 1. `try_new`, the `TryFrom` impls,
/// `from_json`, and both deserializers reject 0. The binary encoding is a single
/// byte, and the serde encoding is a bare number (`#[serde(transparent)]`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Hash, Ord, Clone, Copy, Serial)]
#[derive(SerdeSerialize)]
#[serde(transparent)]
pub struct Threshold(u8);
impl<'de> SerdeDeserialize<'de> for Threshold {
    /// Deserialize from a plain `u8`, reporting a zero value through the
    /// deserializer's custom error built from [`ThresholdZero`].
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: SerdeDeserializer<'de>,
    {
        let raw = u8::deserialize(deserializer)?;
        match Self::try_new(raw) {
            Ok(threshold) => Ok(threshold),
            Err(zero) => Err(D::Error::custom(zero)),
        }
    }
}
impl Deserial for Threshold {
    /// Read a single byte and validate it with [`Threshold::try_new`]; a zero byte is
    /// rejected with a parse error.
    fn deserial<R: ReadBytesExt>(source: &mut R) -> ParseResult<Self> {
        let byte: u8 = source.get()?;
        match Self::try_new(byte) {
            Ok(threshold) => Ok(threshold),
            Err(_) => bail!("Threshold must be at least 1."),
        }
    }
}
/// Error returned by [`Threshold::try_new`] when asked to build a threshold of zero.
/// Its `Display` text is "threshold cannot be zero".
#[derive(Debug, thiserror::Error)]
#[error("threshold cannot be zero")]
pub struct ThresholdZero;
impl Threshold {
    /// Create a threshold, rejecting zero.
    ///
    /// # Errors
    ///
    /// Returns [`ThresholdZero`] if `threshold` is `0`.
    pub fn try_new(threshold: u8) -> Result<Self, ThresholdZero> {
        if threshold >= 1 {
            Ok(Threshold(threshold))
        } else {
            Err(ThresholdZero)
        }
    }

    /// The raw threshold value (at least 1).
    pub fn threshold(self) -> u8 {
        self.0
    }

    /// The threshold embedded as a scalar of the curve `C`.
    pub fn to_scalar<C: Curve>(self) -> C::Scalar {
        C::scalar_from_u64(u64::from(self.0))
    }

    /// The threshold as a bare JSON number.
    pub fn to_json(self) -> Value {
        json!(self.0)
    }

    /// Parse a threshold from a JSON number.
    ///
    /// Returns `None` unless `v` is an unsigned integer in `1..=255`.
    pub fn from_json(v: &Value) -> Option<Self> {
        let raw = u8::try_from(v.as_u64()?).ok()?;
        // Reuse `try_new` so JSON input is validated by the same zero check as every
        // other constructor, rather than a duplicated copy of it.
        Self::try_new(raw).ok()
    }
}
impl From<Threshold> for u8 {
fn from(x: Threshold) -> Self {
x.0
}
}
impl From<Threshold> for usize {
fn from(x: Threshold) -> Self {
x.0.into()
}
}
impl TryFrom<u8> for Threshold {
    type Error = ();
    /// Same check as [`Threshold::try_new`] (zero is rejected), with the error
    /// collapsed to `()`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Self::try_new(value).ok().ok_or(())
    }
}
impl TryFrom<usize> for Threshold {
    type Error = ();
    /// Narrow `value` to `u8` and then apply the `u8` conversion.
    ///
    /// Fails with `()` if `value` is 0 or larger than `u8::MAX`.
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        let byte = u8::try_from(value).map_err(|_| ())?;
        // `TryFrom<u8>` already reports `()`, so no further error mapping is needed.
        Self::try_from(byte)
    }
}
impl std::fmt::Display for Threshold {
    /// Format as the bare number, honouring the caller's width/fill/alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully qualified so the call does not depend on which formatting trait a glob
        // import happens to bring into scope. If both `Display` and `Debug` were in
        // scope, `self.0.fmt(f)` would be ambiguous.
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Output of [`share`]: the sampled polynomial coefficients and the resulting shares.
pub struct SharingData<C: Curve> {
    /// Non-constant coefficients of the sharing polynomial, degree 1 first.
    /// The constant term (the secret) is not included, so the length is
    /// `threshold - 1`.
    pub coefficients: Vec<PedersenValue<C>>,
    /// One share per input point, in the order the points were supplied to [`share`].
    pub shares: Vec<PedersenValue<C>>,
}
/// Split `secret` into one share per entry of `points` (Shamir secret sharing).
///
/// A random polynomial `f` of degree `revealing_threshold - 1` with `f(0) = secret` is
/// sampled, and the share for point `x` is `f(x)`. The leading coefficient is sampled
/// non-zero, so the polynomial has exactly that degree.
///
/// Points should be distinct and non-zero, because the share at point `0` is `secret`
/// itself.
///
/// Returns the non-constant coefficients (degree 1 first) and the shares, in the same
/// order as `points`.
pub fn share<C: Curve, P: Into<u64>, I: IntoIterator<Item = P> + ExactSizeIterator, R: Rng>(
    secret: &C::Scalar,
    points: I,
    revealing_threshold: Threshold,
    csprng: &mut R,
) -> SharingData<C> {
    debug_assert!(revealing_threshold >= Threshold(1));
    let degree: u8 = u8::from(revealing_threshold) - 1;

    let mut coefficients: Vec<PedersenValue<C>> = Vec::with_capacity(usize::from(degree));
    if degree > 0 {
        // Coefficients of degree 1 .. degree - 1 are unrestricted.
        for _ in 1..degree {
            coefficients.push(PedersenValue::generate(csprng));
        }
        // The leading coefficient must be non-zero.
        coefficients.push(PedersenValue::generate_non_zero(csprng));
    }

    let mut shares = Vec::with_capacity(points.len());
    for point in points {
        let x = C::scalar_from_u64(point.into());
        // Horner's rule over the non-constant coefficients (highest degree first),
        // followed by one more step that brings in the constant term `secret`.
        let mut value = coefficients
            .iter()
            .rev()
            .fold(C::Scalar::zero(), |mut acc, coefficient| {
                acc.mul_assign(&x);
                acc.add_assign(coefficient);
                acc
            });
        value.mul_assign(&x);
        value.add_assign(secret);
        shares.push(PedersenValue::new(value));
    }

    SharingData {
        coefficients,
        shares,
    }
}
/// Lagrange basis coefficient for point `i` over the point set `kxs`, evaluated at 0:
/// the product over `j` in `kxs` of `x_j / (x_j - x_i)`.
///
/// Factors whose denominator `x_j - x_i` has no inverse (i.e. `j == i`) are skipped,
/// so `kxs` may contain `i` itself. The result is only meaningful for distinct points.
fn lagrange<P: Into<u64> + Copy, C: Curve>(kxs: &[P], i: P) -> C::Scalar {
    let x_i = C::scalar_from_u64(i.into());
    let mut basis = C::Scalar::one();
    for &j in kxs {
        let x_j = C::scalar_from_u64(j.into());
        let mut denominator = x_j;
        denominator.sub_assign(&x_i);
        if let Some(denominator_inv) = denominator.inverse() {
            let mut factor = x_j;
            factor.mul_assign(&denominator_inv);
            factor.mul_assign(&basis);
            basis = factor;
        }
    }
    basis
}
/// Recover a secret from `(point, share)` pairs by Lagrange interpolation at 0.
///
/// With at least `threshold` shares at distinct points, all produced by one call to
/// [`share`], this yields the shared secret. Fewer or altered shares generally yield
/// a different value. An empty slice yields zero.
pub fn reveal<P: Into<u64> + Copy, C: Curve>(shares: &[(P, PedersenValue<C>)]) -> C::Scalar {
    let points: Vec<P> = shares.iter().map(|pair| pair.0).collect();
    let mut secret = C::Scalar::zero();
    for (point, value) in shares {
        // secret += L_point(0) * value
        let mut term = lagrange::<P, C>(&points, *point);
        term.mul_assign(value);
        term.add_assign(&secret);
        secret = term;
    }
    secret
}
/// Same interpolation as [`reveal`], but on group elements: given `(point, g^share)`
/// pairs, returns `g^secret`. An empty slice yields the zero point.
pub fn reveal_in_group<P: Into<u64> + Copy, C: Curve>(shares: &[(P, C)]) -> C {
    let points: Vec<P> = shares.iter().map(|pair| pair.0).collect();
    let mut result = C::zero_point();
    for (point, share_point) in shares {
        let coefficient = lagrange::<P, C>(&points, *point);
        result = share_point.mul_by_scalar(&coefficient).plus_point(&result);
    }
    result
}
#[cfg(test)]
mod test {
    use crate::curve_arithmetic::arkworks_instances::{ArkField, ArkGroup};
    use super::*;
    use crate::common;
    use ark_bls12_381::{Fr, G1Projective};
    use nom::AsBytes;
    use rand::seq::SliceRandom;

    type G1 = ArkGroup<G1Projective>;

    #[test]
    pub fn test_lagrange() {
        // For a point outside the set, the basis at 0 is prod_j x_j / (x_j - 0) = 1.
        let kxs = [1u32, 2, 3];
        let r = lagrange::<u32, G1>(&kxs, 0);
        assert_eq!(r, G1::scalar_from_u64(1));
        // For point 1 over {1, 2}, the j = 1 factor is skipped, leaving 2 / (2 - 1) = 2.
        let kxs = [1u32, 2];
        let r = lagrange::<u32, G1>(&kxs, 1);
        assert_eq!(r, G1::scalar_from_u64(2));
    }

    #[test]
    pub fn test_share_output_length() {
        let mut csprng = thread_rng();
        let secret = <G1 as Curve>::generate_scalar(&mut csprng);
        // Sample the number of points uniformly. Clamping a raw random `u32` to 200
        // would almost always give exactly 200.
        let n: u32 = csprng.gen_range(1..=200);
        // Start at 1: the share at point 0 would be the secret itself.
        let mut xs = (1..=n).collect::<Vec<_>>();
        xs.shuffle(&mut csprng);
        let t = csprng.gen_range(1..xs.len() + 1);
        let threshold = Threshold::try_from(t).expect("t is between 1 and 200.");
        let shared = share::<G1, _, _, _>(&secret, xs.into_iter(), threshold, &mut csprng);
        assert_eq!(shared.coefficients.len() + 1, t);
        assert_eq!(shared.shares.len(), n as usize);
    }

    #[test]
    pub fn test_secret_sharing() {
        let mut csprng = thread_rng();
        for i in 1u8..10 {
            let generator = G1::one_point()
                .mul_by_scalar(&<G1 as Curve>::generate_non_zero_scalar(&mut csprng));
            let secret = <G1 as Curve>::generate_scalar(&mut csprng);
            let secret_point = generator.mul_by_scalar(&secret);
            let threshold = csprng.gen_range(1..i + 1);
            let mut xs = (1..=i).collect::<Vec<_>>();
            xs.shuffle(&mut csprng);

            // `threshold` honest shares reveal the secret, both as a scalar and in the group.
            let sharing_data = share::<G1, _, _, _>(
                &secret,
                xs.iter().copied(),
                Threshold::try_from(threshold).expect("Threshold is at least 1."),
                &mut csprng,
            );
            let mut shares = xs
                .iter()
                .copied()
                .zip(sharing_data.shares)
                .collect::<Vec<_>>();
            shares.shuffle(&mut csprng);
            let sufficient_sample = &shares[0..(threshold as usize)];
            let sufficient_sample_points = sufficient_sample
                .iter()
                .map(|(n, s)| (*n, generator.mul_by_scalar(s)))
                .collect::<Vec<(u8, G1)>>();
            let revealed_data: ArkField<Fr> = reveal::<_, G1>(sufficient_sample);
            assert_eq!(revealed_data, secret);
            let revealed_data_point: G1 = reveal_in_group::<_, G1>(&sufficient_sample_points);
            assert_eq!(revealed_data_point, secret_point);

            // Replacing one of `threshold` shares with a random value breaks reconstruction.
            let sharing_data = share::<G1, _, _, _>(
                &secret,
                xs.iter().copied(),
                Threshold::try_from(threshold).expect("Threshold is at least 1."),
                &mut csprng,
            );
            let mut shares = xs
                .iter()
                .copied()
                .zip(sharing_data.shares)
                .collect::<Vec<_>>();
            shares.shuffle(&mut csprng);
            shares.truncate(threshold as usize);
            let rand_elm = shares.choose_mut(&mut csprng).unwrap();
            rand_elm.1 = crate::curve_arithmetic::Value::generate(&mut csprng);
            let revealed_data: ArkField<Fr> = reveal::<_, G1>(&shares);
            assert_ne!(revealed_data, secret);
            let sufficient_points_err = shares
                .iter()
                .map(|(n, s)| (*n, generator.mul_by_scalar(s)))
                .collect::<Vec<(u8, G1)>>();
            let revealed_data_point: G1 = reveal_in_group::<_, G1>(&sufficient_points_err);
            assert_ne!(revealed_data_point, secret_point);

            // `threshold - 1` shares are not enough to reveal the secret.
            let sharing_data = share::<G1, _, _, _>(
                &secret,
                xs.iter().copied(),
                Threshold::try_from(threshold).expect("Threshold is at least 1."),
                &mut csprng,
            );
            let mut insufficient_shares = xs
                .iter()
                .copied()
                .zip(sharing_data.shares)
                .collect::<Vec<_>>();
            insufficient_shares.shuffle(&mut csprng);
            let insufficient_sample = &insufficient_shares[0..((threshold - 1) as usize)];
            let insufficient_sample_points = insufficient_sample
                .iter()
                .map(|(n, s)| (*n, generator.mul_by_scalar(s)))
                .collect::<Vec<(u8, G1)>>();
            let revealed_data: ArkField<Fr> = reveal::<_, G1>(insufficient_sample);
            assert_ne!(revealed_data, secret);
            let revealed_data_point: G1 = reveal_in_group::<_, G1>(&insufficient_sample_points);
            assert_ne!(revealed_data_point, secret_point);
        }
    }

    #[test]
    fn test_threshold_try_new() {
        let threshold = Threshold::try_new(1).expect("try_new");
        assert_eq!(threshold.threshold(), 1);
        assert_eq!(u8::from(threshold), 1);
        assert_eq!(usize::from(threshold), 1);
        let threshold = Threshold::try_new(3).expect("try_new");
        assert_eq!(threshold.threshold(), 3);
        Threshold::try_new(0).expect_err("try_new");
    }

    #[test]
    fn test_threshold_try_from() {
        let threshold = Threshold::try_from(1u8).expect("try_from");
        assert_eq!(threshold.threshold(), 1);
        let threshold = Threshold::try_from(1usize).expect("try_from");
        assert_eq!(threshold.threshold(), 1);
        let threshold = Threshold::try_from(3u8).expect("try_from");
        assert_eq!(threshold.threshold(), 3);
        let threshold = Threshold::try_from(255usize).expect("try_from");
        assert_eq!(threshold.threshold(), 255);
        Threshold::try_from(0u8).expect_err("try_from");
        Threshold::try_from(256usize).expect_err("try_from");
    }

    #[test]
    fn test_threshold_serial_deserial() {
        let threshold = Threshold::try_new(2).unwrap();
        let bytes_hex = hex::encode(common::to_bytes(&threshold));
        assert_eq!(bytes_hex, "02");
        let threshold_deserial: Threshold =
            common::from_bytes(&mut hex::decode(bytes_hex).unwrap().as_bytes()).expect("deserial");
        assert_eq!(threshold_deserial, threshold);
        let bytes_hex = "00";
        let err =
            common::from_bytes::<Threshold, _>(&mut hex::decode(bytes_hex).unwrap().as_bytes())
                .expect_err("deserial");
        assert!(
            err.to_string().contains("Threshold must be at least 1"),
            "message: {}",
            err
        );
    }

    #[test]
    fn test_threshold_serde_serialize_deserialize() {
        let threshold = Threshold::try_new(2).unwrap();
        let json = serde_json::to_string(&threshold).expect("serialize");
        assert_eq!(json, r#"2"#);
        let threshold_deserialized: Threshold = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(threshold_deserialized, threshold);
        let json = r#"0"#;
        let err = serde_json::from_str::<Threshold>(json).expect_err("deserial");
        assert!(
            err.to_string().contains("threshold cannot be zero"),
            "message: {}",
            err
        );
    }

    #[test]
    fn test_threshold_json_value() {
        let threshold = Threshold::from_json(&serde_json::Value::from(1u8)).unwrap();
        assert_eq!(threshold.threshold(), 1);
        assert_eq!(threshold.to_json(), serde_json::Value::from(1u8));
        assert!(Threshold::from_json(&serde_json::Value::from(0u8)).is_none());
    }
}