use std::{
collections::{BTreeMap, HashMap},
iter::repeat_n,
};
use enum_try_as_inner::EnumTryAsInner;
use ff::Field;
use itertools::{izip, Itertools};
use num_bigint::BigUint;
use num_traits::{FromBytes, One};
use primitives::{
algebra::{
elliptic_curve::{BaseField, BaseFieldElement, Curve, Point, Scalar, ScalarField},
field::{
mersenne::Mersenne107,
subfield_element::Mersenne107Element,
Bit,
FieldExtension,
SubfieldElement,
},
},
izip_eq,
};
use rand::Rng;
use typenum::Unsigned;
use crate::{
circuit::{
AlgebraicType,
BitShareBinaryOp,
BitShareUnaryOp,
Circuit,
Constant,
FieldShareBinaryOp,
FieldShareUnaryOp,
FieldType,
Gate,
GateExt,
GateIndex,
Input,
PointPlaintextBinaryOp,
PointShareBinaryOp,
PointShareUnaryOp,
ShareOrPlaintext,
},
key_recovery::compute_errors::compute_errors,
};
/// A batch of values carried on one wire during mock evaluation.
///
/// Every variant wraps one value per batch element. In this mock evaluator the
/// `SecretShared*` variants hold the cleartext value itself (opening a share is just a
/// clone), so the plaintext / secret-shared split only tracks the form of the wire.
#[derive(Debug, Clone, EnumTryAsInner)]
#[derive_err(Debug, Clone)]
pub enum RunTimeVal<C: Curve> {
    // Elements of the scalar field of `C`.
    PlaintextScalar(Vec<Scalar<C>>),
    SecretSharedScalar(Vec<Scalar<C>>),
    // Elements of the base field of `C`.
    PlaintextBase(Vec<BaseFieldElement<C>>),
    SecretSharedBase(Vec<BaseFieldElement<C>>),
    // Elements of the Mersenne107 field.
    PlaintextMersenne107(Vec<Mersenne107Element>),
    SecretSharedMersenne107(Vec<Mersenne107Element>),
    // Single bits.
    PlaintextBit(Vec<bool>),
    SecretSharedBit(Vec<bool>),
    // Bits produced by `DaBit` gates; consumers read them either as shared bits
    // (`GetDaBitSharedBit`) or as field shares (`GetDaBitFieldShare`).
    DaBit(Vec<bool>),
    // Points on the curve `C`.
    PlaintextPoint(Vec<Point<C>>),
    SecretSharedPoint(Vec<Point<C>>),
}
impl<C: Curve> RunTimeVal<C> {
fn into_big_uint(self) -> Vec<BigUint> {
match self {
RunTimeVal::PlaintextScalar(c) | RunTimeVal::SecretSharedScalar(c) => {
c.into_iter().map(|s| s.to_biguint()).collect()
}
RunTimeVal::PlaintextBase(c) | RunTimeVal::SecretSharedBase(c) => {
c.into_iter().map(|s| s.to_biguint()).collect()
}
RunTimeVal::PlaintextMersenne107(c) | RunTimeVal::SecretSharedMersenne107(c) => {
c.into_iter().map(|s| s.to_biguint()).collect()
}
RunTimeVal::PlaintextBit(c) | RunTimeVal::SecretSharedBit(c) | RunTimeVal::DaBit(c) => {
c.into_iter().map(|b| b.into()).collect()
}
RunTimeVal::PlaintextPoint(p) | RunTimeVal::SecretSharedPoint(p) => p
.into_iter()
.map(|p| BigUint::from_le_bytes(p.to_bytes().as_ref()))
.collect(),
}
}
fn len(&self) -> usize {
match self {
RunTimeVal::PlaintextScalar(v) => v.len(),
RunTimeVal::SecretSharedScalar(v) => v.len(),
RunTimeVal::PlaintextBase(v) => v.len(),
RunTimeVal::SecretSharedBase(v) => v.len(),
RunTimeVal::PlaintextMersenne107(v) => v.len(),
RunTimeVal::SecretSharedMersenne107(v) => v.len(),
RunTimeVal::PlaintextBit(v) => v.len(),
RunTimeVal::SecretSharedBit(v) => v.len(),
RunTimeVal::DaBit(v) => v.len(),
RunTimeVal::PlaintextPoint(v) => v.len(),
RunTimeVal::SecretSharedPoint(v) => v.len(),
}
}
}
/// Decodes a field element from its integer representation.
///
/// The integer is written as little-endian bytes and zero-padded up to `F::FieldBytesSize`
/// (`to_bytes_le` drops trailing zero bytes) before being parsed.
///
/// # Panics
///
/// Panics if the integer needs more than `F::FieldBytesSize` bytes, or if the padded bytes
/// are rejected by `SubfieldElement::from_le_bytes`.
fn big_uint_to_field<F: FieldExtension>(a: BigUint) -> SubfieldElement<F> {
    let width = F::FieldBytesSize::to_usize();
    let mut bytes = a.to_bytes_le();
    // Check before padding: `width - bytes.len()` would otherwise underflow for oversized
    // values (a wrap to a huge padding length in release builds).
    assert!(
        bytes.len() <= width,
        "Cannot convert BigUint to field element: value needs {} bytes, expected at most {width}",
        bytes.len()
    );
    bytes.extend(repeat_n(0u8, width - bytes.len()));
    SubfieldElement::<F>::from_le_bytes(&bytes)
        .expect("Cannot convert BigUint to field element: bytes were rejected")
}
/// Decodes a bit from the integer 0 or 1.
///
/// # Panics
///
/// Panics if `a` is greater than 1.
fn big_uint_to_bit(a: BigUint) -> bool {
    let is_binary = a <= BigUint::from(1u8);
    assert!(is_binary, "Cannot convert BigUint to Bit: value is greater than 1");
    a.is_one()
}
/// Decodes a point from the little-endian integer produced by `point_to_big_uint`.
///
/// The integer is re-expanded to a fixed 32-byte little-endian encoding (restoring the
/// trailing zero bytes that `to_bytes_le` drops) and parsed with `Point::from_le_bytes`.
///
/// # Panics
///
/// Panics if the integer needs more than 32 bytes or the bytes do not decode to a point.
fn big_uint_to_point<C: Curve>(a: BigUint) -> Point<C> {
    // NOTE(review): assumes every supported curve uses a 32-byte point encoding — confirm for
    // curves other than the 25519 family.
    const POINT_BYTES: usize = 32;
    let mut bytes = a.to_bytes_le();
    // Must be checked before padding; the subtraction below would underflow otherwise.
    assert!(
        bytes.len() <= POINT_BYTES,
        "Cannot convert BigUint to Point: value needs {} bytes, expected at most {POINT_BYTES}",
        bytes.len()
    );
    bytes.extend(repeat_n(0u8, POINT_BYTES - bytes.len()));
    Point::from_le_bytes(&bytes).expect("Cannot convert BigUint to Point: invalid encoding")
}
/// Encodes a bit as the integer 0 (`false`) or 1 (`true`).
fn bit_to_big_uint(a: bool) -> BigUint {
    BigUint::from(u8::from(a))
}
/// Encodes a point as the little-endian integer of its byte serialization
/// (inverse of `big_uint_to_point`).
fn point_to_big_uint<C: Curve>(p: Point<C>) -> BigUint {
    let encoded = p.to_bytes();
    BigUint::from_le_bytes(encoded.as_ref())
}
/// Values computed so far during mock evaluation, one entry per gate, indexed by
/// `GateIndex` (entries are pushed in gate order).
#[derive(Debug, Default)]
struct IntermediateValues<C: Curve>(Vec<RunTimeVal<C>>);

impl<C: Curve> IntermediateValues<C> {
    /// Appends the value computed for the next gate.
    fn push(&mut self, val: RunTimeVal<C>) {
        let Self(values) = self;
        values.push(val);
    }

    /// Value of gate `index`, or `None` if it has not been computed yet.
    fn get(&self, index: &GateIndex) -> Option<&RunTimeVal<C>> {
        let position = *index as usize;
        self.0.get(position)
    }

    /// Value of gate `index`; panics if it has not been computed yet.
    fn value_at(&self, index: &GateIndex) -> &RunTimeVal<C> {
        self.get(index).unwrap()
    }

    // Typed accessors: each panics if the gate has no value yet or holds another variant.

    fn get_secret_shared_scalar(&self, index: &GateIndex) -> &Vec<Scalar<C>> {
        self.value_at(index).try_as_secret_shared_scalar().unwrap()
    }

    fn get_secret_shared_base(&self, index: &GateIndex) -> &Vec<BaseFieldElement<C>> {
        self.value_at(index).try_as_secret_shared_base().unwrap()
    }

    fn get_secret_shared_bit(&self, index: &GateIndex) -> &Vec<bool> {
        self.value_at(index).try_as_secret_shared_bit().unwrap()
    }

    fn get_plaintext_scalar(&self, index: &GateIndex) -> &Vec<Scalar<C>> {
        self.value_at(index).try_as_plaintext_scalar().unwrap()
    }

    fn get_plaintext_base(&self, index: &GateIndex) -> &Vec<BaseFieldElement<C>> {
        self.value_at(index).try_as_plaintext_base().unwrap()
    }

    fn get_secret_shared_mersenne107(&self, index: &GateIndex) -> &Vec<Mersenne107Element> {
        self.value_at(index)
            .try_as_secret_shared_mersenne107()
            .unwrap()
    }

    fn get_plaintext_mersenne107(&self, index: &GateIndex) -> &Vec<Mersenne107Element> {
        self.value_at(index).try_as_plaintext_mersenne107().unwrap()
    }

    fn get_plaintext_bit(&self, index: &GateIndex) -> &Vec<bool> {
        self.value_at(index).try_as_plaintext_bit().unwrap()
    }

    fn get_da_bit(&self, index: &GateIndex) -> &Vec<bool> {
        self.value_at(index).try_as_da_bit().unwrap()
    }

    fn get_plaintext_point(&self, index: &GateIndex) -> &Vec<Point<C>> {
        self.value_at(index).try_as_plaintext_point().unwrap()
    }

    fn get_secret_shared_point(&self, index: &GateIndex) -> &Vec<Point<C>> {
        self.value_at(index).try_as_secret_shared_point().unwrap()
    }
}
impl<C: Curve> Circuit<C> {
/// Samples, ahead of evaluation, the values produced by every `Random` and `DaBit` gate.
///
/// `Random` gates are sampled first, grouped by `AlgebraicType` (in the key order of the
/// `BTreeMap`) and in circuit order within a type. `DaBit` gates are sampled afterwards,
/// grouped the same way by `FieldType`. Every sampled value is stored in secret-shared
/// form (`DaBit` for DaBit gates), with one element per batch entry of the gate's output.
///
/// Returns a map from gate index to the sampled value.
fn mock_eval_generate_randomness<R: MockRng + ?Sized>(
    &self,
    rng: &mut R,
) -> HashMap<GateIndex, RunTimeVal<C>> {
    let mut dabit_gates_by_type = BTreeMap::<FieldType, Vec<GateIndex>>::new();
    let mut random_gates_by_type = BTreeMap::<AlgebraicType, Vec<GateIndex>>::new();
    for (index, gate) in izip_eq!(0..self.nb_gates(), self.iter_gates()) {
        match gate {
            Gate::Random { algebraic_type, .. } => {
                random_gates_by_type
                    .entry(*algebraic_type)
                    .or_default()
                    .push(index);
            }
            Gate::DaBit { field_type, .. } => {
                dabit_gates_by_type
                    .entry(*field_type)
                    .or_default()
                    .push(index);
            }
            _ => {}
        }
    }
    let mut random_values = HashMap::<GateIndex, RunTimeVal<C>>::new();
    for (algebraic_type, indices) in random_gates_by_type {
        for index in indices {
            let len = self.gate_output_unchecked(index).get_batch_size();
            // Exhaustive on purpose: every algebraic type has a matching `MockRng` sampler,
            // so adding a new type is a compile error here rather than a runtime `todo!()`.
            let val = match algebraic_type {
                AlgebraicType::ScalarField => RunTimeVal::SecretSharedScalar(
                    (0..len).map(|_| rng.gen_scalar::<C>()).collect(),
                ),
                AlgebraicType::BaseField => RunTimeVal::SecretSharedBase(
                    (0..len).map(|_| rng.gen_base::<C>()).collect(),
                ),
                AlgebraicType::Mersenne107 => RunTimeVal::SecretSharedMersenne107(
                    (0..len).map(|_| rng.gen_mersenne()).collect(),
                ),
                AlgebraicType::Bit => {
                    RunTimeVal::SecretSharedBit((0..len).map(|_| rng.gen_bit()).collect())
                }
                AlgebraicType::Point => RunTimeVal::SecretSharedPoint(
                    (0..len).map(|_| rng.gen_point::<C>()).collect(),
                ),
            };
            random_values.insert(index, val);
        }
    }
    for indices in dabit_gates_by_type.into_values() {
        for index in indices {
            let len = self.gate_output_unchecked(index).get_batch_size();
            let val = RunTimeVal::DaBit((0..len).map(|_| rng.gen_da_bit()).collect());
            random_values.insert(index, val);
        }
    }
    random_values
}
/// Evaluates the circuit gate by gate on mock (cleartext) values.
///
/// Secret-shared wires simply hold the value they would share, so "opening" a share is a
/// clone. `inputs` must contain one [`RunTimeVal`] per input gate, in the order in which the
/// input gates appear in the circuit, each using the plaintext / secret-shared variant the
/// gate declares. Values for `Random` and `DaBit` gates are drawn from `rng` before
/// evaluation starts (see `mock_eval_generate_randomness`).
///
/// Returns the values of the output gates, in `iter_output_indices` order.
///
/// # Panics
///
/// Panics if too few inputs are supplied, if an input has the wrong variant, if a wire holds
/// a different variant than its consumer expects, if operand batch sizes disagree, or if a
/// gate's result does not have the gate's declared batch size.
fn mock_eval<R: MockRng + ?Sized>(
    &self,
    inputs: &[RunTimeVal<C>],
    rng: &mut R,
) -> Vec<RunTimeVal<C>> {
    let mut random_values = self.mock_eval_generate_randomness(rng);
    let mut wire_values = IntermediateValues::default();
    // Reversed so that `pop` hands the inputs out in circuit order.
    let mut inputs = inputs.to_vec();
    inputs.reverse();
    for (index, GateExt { gate, output, .. }) in
        izip_eq!(0..self.nb_gates(), self.iter_gates_ext())
    {
        let res = match gate {
            // ---- Field share operations -------------------------------------------------
            Gate::FieldShareUnaryOp { x, op } => match output.get_field_type_unchecked() {
                FieldType::ScalarField => {
                    let x = wire_values.get_secret_shared_scalar(x);
                    match op {
                        FieldShareUnaryOp::Neg => {
                            RunTimeVal::SecretSharedScalar(x.iter().map(|s| -s).collect())
                        }
                        // Zero has no inverse; the mock maps it to zero.
                        FieldShareUnaryOp::MulInverse => RunTimeVal::SecretSharedScalar(
                            x.iter()
                                .map(|s| s.invert().unwrap_or(Scalar::<C>::ZERO))
                                .collect(),
                        ),
                        FieldShareUnaryOp::Open => RunTimeVal::PlaintextScalar(x.clone()),
                        FieldShareUnaryOp::IsZero => RunTimeVal::PlaintextScalar(
                            x.iter().map(|s| (*s == Scalar::<C>::ZERO).into()).collect(),
                        ),
                    }
                }
                FieldType::BaseField => {
                    let x = wire_values.get_secret_shared_base(x);
                    match op {
                        FieldShareUnaryOp::Neg => {
                            RunTimeVal::SecretSharedBase(x.iter().map(|s| -s).collect())
                        }
                        FieldShareUnaryOp::MulInverse => RunTimeVal::SecretSharedBase(
                            x.iter()
                                .map(|s| s.invert().unwrap_or(BaseFieldElement::<C>::ZERO))
                                .collect(),
                        ),
                        FieldShareUnaryOp::Open => RunTimeVal::PlaintextBase(x.clone()),
                        FieldShareUnaryOp::IsZero => RunTimeVal::PlaintextBase(
                            x.iter()
                                .map(|s| (*s == BaseFieldElement::<C>::ZERO).into())
                                .collect(),
                        ),
                    }
                }
                FieldType::Mersenne107 => {
                    let x = wire_values.get_secret_shared_mersenne107(x);
                    match op {
                        FieldShareUnaryOp::Neg => {
                            RunTimeVal::SecretSharedMersenne107(x.iter().map(|s| -s).collect())
                        }
                        FieldShareUnaryOp::MulInverse => RunTimeVal::SecretSharedMersenne107(
                            x.iter()
                                .map(|s| s.invert().unwrap_or(Mersenne107Element::ZERO))
                                .collect(),
                        ),
                        FieldShareUnaryOp::Open => RunTimeVal::PlaintextMersenne107(x.clone()),
                        FieldShareUnaryOp::IsZero => RunTimeVal::PlaintextMersenne107(
                            x.iter()
                                .map(|s| (*s == Mersenne107Element::ZERO).into())
                                .collect(),
                        ),
                    }
                }
            },
            // `x` is always a share; `y` may be a share or a plaintext.
            Gate::FieldShareBinaryOp { x, y, op } => {
                let y_form = self.gate_output_unchecked(*y).get_form();
                match output.get_field_type_unchecked() {
                    FieldType::ScalarField => {
                        let x = wire_values.get_secret_shared_scalar(x);
                        let y = match y_form {
                            ShareOrPlaintext::Share => wire_values.get_secret_shared_scalar(y),
                            ShareOrPlaintext::Plaintext => wire_values.get_plaintext_scalar(y),
                        };
                        RunTimeVal::SecretSharedScalar(match op {
                            FieldShareBinaryOp::Add => {
                                izip_eq!(x, y).map(|(a, b)| a + b).collect()
                            }
                            FieldShareBinaryOp::Mul => {
                                izip_eq!(x, y).map(|(a, b)| a * b).collect()
                            }
                        })
                    }
                    FieldType::BaseField => {
                        let x = wire_values.get_secret_shared_base(x);
                        let y = match y_form {
                            ShareOrPlaintext::Share => wire_values.get_secret_shared_base(y),
                            ShareOrPlaintext::Plaintext => wire_values.get_plaintext_base(y),
                        };
                        RunTimeVal::SecretSharedBase(match op {
                            FieldShareBinaryOp::Add => {
                                izip_eq!(x, y).map(|(a, b)| a + b).collect()
                            }
                            FieldShareBinaryOp::Mul => {
                                izip_eq!(x, y).map(|(a, b)| a * b).collect()
                            }
                        })
                    }
                    FieldType::Mersenne107 => {
                        let x = wire_values.get_secret_shared_mersenne107(x);
                        let y = match y_form {
                            ShareOrPlaintext::Share => {
                                wire_values.get_secret_shared_mersenne107(y)
                            }
                            ShareOrPlaintext::Plaintext => {
                                wire_values.get_plaintext_mersenne107(y)
                            }
                        };
                        RunTimeVal::SecretSharedMersenne107(match op {
                            FieldShareBinaryOp::Add => {
                                izip_eq!(x, y).map(|(a, b)| a + b).collect()
                            }
                            FieldShareBinaryOp::Mul => {
                                izip_eq!(x, y).map(|(a, b)| a * b).collect()
                            }
                        })
                    }
                }
            }
            // ---- Field plaintext operations ---------------------------------------------
            Gate::FieldPlaintextUnaryOp { x, op } => match output.get_field_type_unchecked() {
                FieldType::ScalarField => {
                    let x = wire_values.get_plaintext_scalar(x);
                    let vals: Result<Vec<_>, _> =
                        x.iter().map(|val| op.eval(index, val)).collect();
                    RunTimeVal::PlaintextScalar(vals.unwrap())
                }
                FieldType::BaseField => {
                    let x = wire_values.get_plaintext_base(x);
                    let vals: Result<Vec<_>, _> =
                        x.iter().map(|val| op.eval(index, val)).collect();
                    RunTimeVal::PlaintextBase(vals.unwrap())
                }
                FieldType::Mersenne107 => {
                    let x = wire_values.get_plaintext_mersenne107(x);
                    let vals: Result<Vec<_>, _> =
                        x.iter().map(|val| op.eval(index, val)).collect();
                    RunTimeVal::PlaintextMersenne107(vals.unwrap())
                }
            },
            Gate::FieldPlaintextBinaryOp { x, y, op } => {
                match output.get_field_type_unchecked() {
                    FieldType::ScalarField => {
                        let x = wire_values.get_plaintext_scalar(x);
                        let y = wire_values.get_plaintext_scalar(y);
                        let vals: Result<Vec<_>, _> = izip_eq!(x, y)
                            .map(|(x_val, y_val)| op.eval(x_val, y_val, index))
                            .collect();
                        RunTimeVal::PlaintextScalar(vals.unwrap())
                    }
                    FieldType::BaseField => {
                        let x_val = wire_values.get_plaintext_base(x);
                        let y_val = wire_values.get_plaintext_base(y);
                        let vals: Result<Vec<_>, _> = izip_eq!(x_val, y_val)
                            .map(|(x_val, y_val)| op.eval(x_val, y_val, index))
                            .collect();
                        RunTimeVal::PlaintextBase(vals.unwrap())
                    }
                    FieldType::Mersenne107 => {
                        let x = wire_values.get_plaintext_mersenne107(x);
                        let y = wire_values.get_plaintext_mersenne107(y);
                        let vals: Result<Vec<_>, _> = izip_eq!(x, y)
                            .map(|(x_val, y_val)| op.eval(x_val, y_val, index))
                            .collect();
                        RunTimeVal::PlaintextMersenne107(vals.unwrap())
                    }
                }
            }
            // ---- Bit operations ---------------------------------------------------------
            Gate::BitPlaintextUnaryOp { x, op } => {
                let x = wire_values.get_plaintext_bit(x);
                let vals = x
                    .iter()
                    .map(|val| op.eval(Bit::from(*val)).inner().into())
                    .collect::<Vec<bool>>();
                RunTimeVal::PlaintextBit(vals)
            }
            Gate::BitPlaintextBinaryOp { x, y, op } => {
                let x = wire_values.get_plaintext_bit(x);
                let y = wire_values.get_plaintext_bit(y);
                let vals = izip!(x, y)
                    .map(|(x_val, y_val)| {
                        op.eval(Bit::from(*x_val), Bit::from(*y_val)).inner().into()
                    })
                    .collect::<Vec<bool>>();
                RunTimeVal::PlaintextBit(vals)
            }
            Gate::BitShareUnaryOp { x, op } => {
                let x = wire_values.get_secret_shared_bit(x);
                match op {
                    BitShareUnaryOp::Not => {
                        let vals = x.iter().map(|x| !x).collect::<Vec<_>>();
                        RunTimeVal::SecretSharedBit(vals)
                    }
                    BitShareUnaryOp::Open => RunTimeVal::PlaintextBit(x.clone()),
                }
            }
            Gate::BitShareBinaryOp { x, y, op } => {
                let x = wire_values.get_secret_shared_bit(x);
                let y_form = self.gate_output_unchecked(*y).get_form();
                let y = match y_form {
                    ShareOrPlaintext::Share => wire_values.get_secret_shared_bit(y),
                    ShareOrPlaintext::Plaintext => wire_values.get_plaintext_bit(y),
                };
                RunTimeVal::SecretSharedBit(match op {
                    BitShareBinaryOp::And => {
                        izip!(x, y).map(|(x_val, y_val)| x_val & y_val).collect()
                    }
                    BitShareBinaryOp::Xor => {
                        izip!(x, y).map(|(x_val, y_val)| x_val ^ y_val).collect()
                    }
                    BitShareBinaryOp::Or => {
                        izip!(x, y).map(|(x_val, y_val)| x_val | y_val).collect()
                    }
                })
            }
            // ---- DaBits and bit/field conversions ---------------------------------------
            Gate::DaBit { .. } => random_values.remove(&index).expect("DaBit value not found"),
            Gate::GetDaBitSharedBit { x } => {
                let x_val = wire_values.get_da_bit(x);
                RunTimeVal::SecretSharedBit(x_val.clone())
            }
            Gate::GetDaBitFieldShare { x } => {
                let x_val = wire_values.get_da_bit(x);
                match output.get_field_type_unchecked() {
                    FieldType::ScalarField => RunTimeVal::SecretSharedScalar(
                        x_val.iter().map(|b| Scalar::<C>::from(*b)).collect(),
                    ),
                    FieldType::BaseField => RunTimeVal::SecretSharedBase(
                        x_val
                            .iter()
                            .map(|b| BaseFieldElement::<C>::from(*b))
                            .collect(),
                    ),
                    FieldType::Mersenne107 => RunTimeVal::SecretSharedMersenne107(
                        x_val.iter().map(|b| Mersenne107Element::from(*b)).collect(),
                    ),
                }
            }
            Gate::BitPlaintextToField { x, field_type } => {
                let x_val = wire_values.get_plaintext_bit(x);
                match field_type {
                    FieldType::ScalarField => RunTimeVal::PlaintextScalar(
                        x_val.iter().map(|b| Scalar::<C>::from(*b)).collect(),
                    ),
                    FieldType::BaseField => RunTimeVal::PlaintextBase(
                        x_val
                            .iter()
                            .map(|b| BaseFieldElement::<C>::from(*b))
                            .collect(),
                    ),
                    FieldType::Mersenne107 => RunTimeVal::PlaintextMersenne107(
                        x_val.iter().map(|b| Mersenne107Element::from(*b)).collect(),
                    ),
                }
            }
            // Only 0 and 1 are accepted; anything else panics.
            Gate::FieldPlaintextToBit { x } => match self
                .gate_output_unchecked(*x)
                .get_field_type_unchecked()
            {
                FieldType::ScalarField => {
                    let x = wire_values.get_plaintext_scalar(x);
                    let vals: Vec<bool> = x
                        .iter()
                        .map(|s| {
                            if s.is_zero().into() {
                                false
                            } else if s.is_one() {
                                true
                            } else {
                                panic!("Cannot convert non-binary scalar value to Bit");
                            }
                        })
                        .collect();
                    RunTimeVal::PlaintextBit(vals)
                }
                FieldType::BaseField => {
                    let x = wire_values.get_plaintext_base(x);
                    let vals: Vec<bool> = x
                        .iter()
                        .map(|s| {
                            if s.is_zero().into() {
                                false
                            } else if s.is_one() {
                                true
                            } else {
                                panic!("Cannot convert non-binary base value to Bit");
                            }
                        })
                        .collect();
                    RunTimeVal::PlaintextBit(vals)
                }
                FieldType::Mersenne107 => {
                    let x = wire_values.get_plaintext_mersenne107(x);
                    let vals: Vec<bool> = x
                        .iter()
                        .map(|s| {
                            if s.is_zero().into() {
                                false
                            } else if s.is_one() {
                                true
                            } else {
                                panic!("Cannot convert non-binary mersenne107 value to Bit");
                            }
                        })
                        .collect();
                    RunTimeVal::PlaintextBit(vals)
                }
            },
            // ---- Inputs, constants and randomness ---------------------------------------
            Gate::Input(input_type) => {
                let val = inputs
                    .pop()
                    .expect("mock_eval: fewer inputs were supplied than the circuit has input gates");
                // Check that the supplied value has the variant the input gate declares.
                match input_type {
                    Input::SecretPlaintext { algebraic_type, .. }
                    | Input::Share { algebraic_type, .. } => match algebraic_type {
                        AlgebraicType::ScalarField => {
                            assert!(val.is_secret_shared_scalar());
                        }
                        AlgebraicType::BaseField => {
                            assert!(val.is_secret_shared_base());
                        }
                        AlgebraicType::Bit => {
                            assert!(val.is_secret_shared_bit());
                        }
                        AlgebraicType::Mersenne107 => {
                            assert!(val.is_secret_shared_mersenne107());
                        }
                        AlgebraicType::Point => {
                            assert!(val.is_secret_shared_point());
                        }
                    },
                    Input::Plaintext { algebraic_type, .. } => match algebraic_type {
                        AlgebraicType::ScalarField => {
                            assert!(val.is_plaintext_scalar());
                        }
                        AlgebraicType::BaseField => {
                            assert!(val.is_plaintext_base());
                        }
                        AlgebraicType::Bit => {
                            assert!(val.is_plaintext_bit());
                        }
                        AlgebraicType::Mersenne107 => {
                            assert!(val.is_plaintext_mersenne107());
                        }
                        AlgebraicType::Point => {
                            assert!(val.is_plaintext_point());
                        }
                    },
                }
                val
            }
            Gate::Constant(constant) => match constant {
                Constant::Scalar(c) => RunTimeVal::PlaintextScalar(vec![*c]),
                Constant::ScalarBatch(c) => RunTimeVal::PlaintextScalar(c.clone()),
                Constant::BaseField(c) => RunTimeVal::PlaintextBase(vec![*c]),
                Constant::BaseFieldBatch(c) => RunTimeVal::PlaintextBase(c.clone()),
                Constant::Mersenne107(c) => RunTimeVal::PlaintextMersenne107(vec![*c]),
                Constant::Mersenne107Batch(c) => RunTimeVal::PlaintextMersenne107(c.clone()),
                Constant::Bit(c) => RunTimeVal::PlaintextBit(vec![c.inner().into()]),
                Constant::BitBatch(c) => {
                    RunTimeVal::PlaintextBit(c.iter().map(|b| b.inner().into()).collect())
                }
                Constant::Point(c) => RunTimeVal::PlaintextPoint(vec![*c]),
                Constant::PointBatch(c) => RunTimeVal::PlaintextPoint(c.clone()),
            },
            Gate::Random { .. } => random_values
                .remove(&index)
                .expect("Random gate value not found"),
            // ---- Batch manipulation -----------------------------------------------------
            // Sums the whole batch into a single element (XOR for bits).
            Gate::BatchSummation { x } => {
                let x_form = output.get_form();
                match output.get_type() {
                    AlgebraicType::Mersenne107 => match x_form {
                        ShareOrPlaintext::Share => {
                            let x = wire_values.get_secret_shared_mersenne107(x);
                            let res = x.iter().sum();
                            RunTimeVal::SecretSharedMersenne107(vec![res])
                        }
                        ShareOrPlaintext::Plaintext => {
                            let x = wire_values.get_plaintext_mersenne107(x);
                            let res = x.iter().sum();
                            RunTimeVal::PlaintextMersenne107(vec![res])
                        }
                    },
                    AlgebraicType::BaseField => match x_form {
                        ShareOrPlaintext::Share => {
                            let x = wire_values.get_secret_shared_base(x);
                            let res = x.iter().sum();
                            RunTimeVal::SecretSharedBase(vec![res])
                        }
                        ShareOrPlaintext::Plaintext => {
                            let x = wire_values.get_plaintext_base(x);
                            let res = x.iter().sum();
                            RunTimeVal::PlaintextBase(vec![res])
                        }
                    },
                    AlgebraicType::ScalarField => match x_form {
                        ShareOrPlaintext::Share => {
                            let x = wire_values.get_secret_shared_scalar(x);
                            let res = x.iter().sum();
                            RunTimeVal::SecretSharedScalar(vec![res])
                        }
                        ShareOrPlaintext::Plaintext => {
                            let x = wire_values.get_plaintext_scalar(x);
                            let res = x.iter().sum();
                            RunTimeVal::PlaintextScalar(vec![res])
                        }
                    },
                    AlgebraicType::Point => match x_form {
                        ShareOrPlaintext::Share => {
                            let x = wire_values.get_secret_shared_point(x);
                            let res = x.iter().sum();
                            RunTimeVal::SecretSharedPoint(vec![res])
                        }
                        ShareOrPlaintext::Plaintext => {
                            let x = wire_values.get_plaintext_point(x);
                            let res = x.iter().sum();
                            RunTimeVal::PlaintextPoint(vec![res])
                        }
                    },
                    AlgebraicType::Bit => match x_form {
                        ShareOrPlaintext::Share => {
                            let x = wire_values.get_secret_shared_bit(x);
                            let res = x.iter().fold(false, |acc, x| acc ^ x);
                            RunTimeVal::SecretSharedBit(vec![res])
                        }
                        ShareOrPlaintext::Plaintext => {
                            let x = wire_values.get_plaintext_bit(x);
                            let res = x.iter().fold(false, |acc, x| acc ^ x);
                            RunTimeVal::PlaintextBit(vec![res])
                        }
                    },
                }
            }
            // ---- Point operations -------------------------------------------------------
            Gate::PointShareUnaryOp { p, op } => {
                let p = wire_values.get_secret_shared_point(p);
                match op {
                    PointShareUnaryOp::Neg => {
                        RunTimeVal::SecretSharedPoint(p.iter().map(|x| -x).collect())
                    }
                    PointShareUnaryOp::Open => RunTimeVal::PlaintextPoint(p.clone()),
                    PointShareUnaryOp::IsZero => RunTimeVal::PlaintextScalar(
                        p.iter()
                            .map(|x| Scalar::<C>::from(*x == Point::identity()))
                            .collect(),
                    ),
                }
            }
            Gate::PointShareBinaryOp { p, y, op } => {
                let p_form = self.gate_output_unchecked(*p).form;
                let y_form = self.gate_output_unchecked(*y).form;
                let p = match p_form {
                    ShareOrPlaintext::Share => wire_values.get_secret_shared_point(p),
                    ShareOrPlaintext::Plaintext => wire_values.get_plaintext_point(p),
                };
                match op {
                    PointShareBinaryOp::Add => {
                        let y = match y_form {
                            ShareOrPlaintext::Share => wire_values.get_secret_shared_point(y),
                            ShareOrPlaintext::Plaintext => wire_values.get_plaintext_point(y),
                        };
                        assert_eq!(p.len(), y.len());
                        let v = p.iter().zip(y).map(|(p, y)| p + y).collect();
                        RunTimeVal::SecretSharedPoint(v)
                    }
                    PointShareBinaryOp::ScalarMul => {
                        let y = match y_form {
                            ShareOrPlaintext::Share => wire_values.get_secret_shared_scalar(y),
                            ShareOrPlaintext::Plaintext => wire_values.get_plaintext_scalar(y),
                        };
                        assert_eq!(p.len(), y.len());
                        let v = p.iter().zip(y).map(|(p, y)| y * p).collect();
                        RunTimeVal::SecretSharedPoint(v)
                    }
                }
            }
            Gate::PointPlaintextUnaryOp { p, op } => {
                let p = wire_values.get_plaintext_point(p);
                let res = p.iter().map(|p| op.eval(p).unwrap()).collect();
                RunTimeVal::PlaintextPoint(res)
            }
            Gate::PointPlaintextBinaryOp { p, y, op } => {
                let p = wire_values.get_plaintext_point(p);
                match op {
                    PointPlaintextBinaryOp::Add => {
                        let y = wire_values.get_plaintext_point(y);
                        let res = izip_eq!(p, y)
                            .map(|(p, y)| op.eval(p, y).unwrap())
                            .collect();
                        RunTimeVal::PlaintextPoint(res)
                    }
                    PointPlaintextBinaryOp::ScalarMul => {
                        let y = wire_values.get_plaintext_scalar(y);
                        let res = izip_eq!(p, y).map(|(p, y)| y * p).collect();
                        RunTimeVal::PlaintextPoint(res)
                    }
                }
            }
            Gate::BaseFieldPow { x, exp } => {
                let x = wire_values.get_secret_shared_base(x);
                RunTimeVal::SecretSharedBase(x.iter().map(|x| x.pow(exp)).collect())
            }
            Gate::ExtractFromBatch { x, slice } => {
                let indices = slice.get_indices();
                // Dispatch on the runtime value rather than on the producing gate, so that a
                // DaBit batch stays a DaBit batch under repeated extraction (e.g. extracting
                // from a wire that is itself an `ExtractFromBatch` of a `DaBit` gate).
                if let Some(RunTimeVal::DaBit(x_val)) = wire_values.get(x) {
                    let x_val = indices.into_iter().map(|i| x_val[i as usize]).collect_vec();
                    RunTimeVal::DaBit(x_val)
                } else {
                    // Extraction keeps the form and type, so the output's describe `x` too.
                    let x_form = output.form;
                    match output.algebraic_type {
                        AlgebraicType::Mersenne107 => match x_form {
                            ShareOrPlaintext::Share => {
                                let x = wire_values.get_secret_shared_mersenne107(x);
                                RunTimeVal::SecretSharedMersenne107(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                            ShareOrPlaintext::Plaintext => {
                                let x = wire_values.get_plaintext_mersenne107(x);
                                RunTimeVal::PlaintextMersenne107(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                        },
                        AlgebraicType::BaseField => match x_form {
                            ShareOrPlaintext::Share => {
                                let x = wire_values.get_secret_shared_base(x);
                                RunTimeVal::SecretSharedBase(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                            ShareOrPlaintext::Plaintext => {
                                let x = wire_values.get_plaintext_base(x);
                                RunTimeVal::PlaintextBase(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                        },
                        AlgebraicType::ScalarField => match x_form {
                            ShareOrPlaintext::Share => {
                                let x = wire_values.get_secret_shared_scalar(x);
                                RunTimeVal::SecretSharedScalar(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                            ShareOrPlaintext::Plaintext => {
                                let x = wire_values.get_plaintext_scalar(x);
                                RunTimeVal::PlaintextScalar(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                        },
                        AlgebraicType::Point => match x_form {
                            ShareOrPlaintext::Share => {
                                let x = wire_values.get_secret_shared_point(x);
                                RunTimeVal::SecretSharedPoint(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                            ShareOrPlaintext::Plaintext => {
                                let x = wire_values.get_plaintext_point(x);
                                RunTimeVal::PlaintextPoint(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                        },
                        AlgebraicType::Bit => match x_form {
                            ShareOrPlaintext::Share => {
                                let x = wire_values.get_secret_shared_bit(x);
                                RunTimeVal::SecretSharedBit(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                            ShareOrPlaintext::Plaintext => {
                                let x = wire_values.get_plaintext_bit(x);
                                RunTimeVal::PlaintextBit(
                                    indices.into_iter().map(|i| x[i as usize]).collect(),
                                )
                            }
                        },
                    }
                }
            }
            // Concatenates the batches of `wires`, in order.
            Gate::CollectToBatch { wires } => {
                let x_type = output.algebraic_type;
                let x_form = output.form;
                match x_type {
                    AlgebraicType::Mersenne107 => match x_form {
                        ShareOrPlaintext::Share => RunTimeVal::SecretSharedMersenne107(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_secret_shared_mersenne107(x))
                                .copied()
                                .collect(),
                        ),
                        ShareOrPlaintext::Plaintext => RunTimeVal::PlaintextMersenne107(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_plaintext_mersenne107(x))
                                .copied()
                                .collect(),
                        ),
                    },
                    AlgebraicType::BaseField => match x_form {
                        ShareOrPlaintext::Share => RunTimeVal::SecretSharedBase(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_secret_shared_base(x))
                                .copied()
                                .collect(),
                        ),
                        ShareOrPlaintext::Plaintext => RunTimeVal::PlaintextBase(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_plaintext_base(x))
                                .copied()
                                .collect(),
                        ),
                    },
                    AlgebraicType::ScalarField => match x_form {
                        ShareOrPlaintext::Share => RunTimeVal::SecretSharedScalar(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_secret_shared_scalar(x))
                                .copied()
                                .collect(),
                        ),
                        ShareOrPlaintext::Plaintext => RunTimeVal::PlaintextScalar(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_plaintext_scalar(x))
                                .copied()
                                .collect(),
                        ),
                    },
                    AlgebraicType::Point => match x_form {
                        ShareOrPlaintext::Share => RunTimeVal::SecretSharedPoint(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_secret_shared_point(x))
                                .copied()
                                .collect(),
                        ),
                        ShareOrPlaintext::Plaintext => RunTimeVal::PlaintextPoint(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_plaintext_point(x))
                                .copied()
                                .collect(),
                        ),
                    },
                    AlgebraicType::Bit => match x_form {
                        ShareOrPlaintext::Share => RunTimeVal::SecretSharedBit(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_secret_shared_bit(x))
                                .copied()
                                .collect(),
                        ),
                        ShareOrPlaintext::Plaintext => RunTimeVal::PlaintextBit(
                            wires
                                .iter()
                                .flat_map(|x| wire_values.get_plaintext_bit(x))
                                .copied()
                                .collect(),
                        ),
                    },
                }
            }
            // ---- Plaintext utilities ----------------------------------------------------
            // Builds one point from the first element of each coordinate wire; coordinates
            // that do not form a valid point fall back to the identity.
            Gate::PointFromPlaintextExtendedEdwards { wires } => {
                let coordinates = wires
                    .iter()
                    .map(|x| wire_values.get_plaintext_base(x)[0])
                    .collect_vec();
                RunTimeVal::PlaintextPoint(vec![Point::from_extended_edwards(
                    coordinates.try_into().expect("Extended Edwards coordinates must consist of a quadruple of BaseFieldElements.")
                ).unwrap_or_else(|| Point::identity())])
            }
            Gate::PlaintextPointToExtendedEdwards { point } => {
                let point = wire_values.get_plaintext_point(point)[0];
                RunTimeVal::PlaintextBase(
                    point
                        .to_extended_edwards()
                        .map_err(|_| ())
                        .expect("point did not convert to edwards")
                        .to_vec(),
                )
            }
            Gate::PlaintextKeccakF1600 { x } => {
                let bits = wire_values.get_plaintext_bit(x);
                // Keccak-f[1600] operates on 25 lanes of 64 bits.
                let mut state = [0u64; 25];
                assert!(
                    bits.len() <= 64 * state.len(),
                    "Keccak-f[1600] input has {} bits, expected at most 1600",
                    bits.len()
                );
                // Bit `64 * i + j` of the input becomes bit `j` of lane `i`.
                bits.chunks(64).enumerate().for_each(|(i, chunk)| {
                    let mut val = 0u64;
                    for (j, c) in chunk.iter().enumerate() {
                        val |= u64::from(*c) << j;
                    }
                    state[i] = val;
                });
                keccak::f1600(&mut state);
                // Unpack the lanes with the same little-endian bit order.
                RunTimeVal::PlaintextBit(
                    state
                        .into_iter()
                        .flat_map(|val| {
                            std::array::from_fn::<_, 64, _>(|j| (val >> j) & 1u64 == 1u64)
                        })
                        .collect::<Vec<bool>>(),
                )
            }
            // Emits the point's byte encoding as bits, least-significant bit of each byte first.
            Gate::CompressPlaintextPoint { point } => {
                let point = wire_values.get_plaintext_point(point)[0];
                let compressed = point.to_bytes();
                RunTimeVal::PlaintextBit(
                    compressed
                        .iter()
                        .flat_map(|byte| {
                            (0..8)
                                .map(|i| (byte >> i) & 1u8 == 1u8)
                                .collect::<Vec<bool>>()
                        })
                        .collect::<Vec<bool>>(),
                )
            }
            Gate::KeyRecoveryPlaintextComputeErrors {
                d_minus_one,
                syndromes,
            } => {
                let d_minus_one = wire_values.get_plaintext_base(d_minus_one)[0];
                let syndromes = wire_values.get_plaintext_base(syndromes);
                RunTimeVal::PlaintextBase(
                    compute_errors::<C>(d_minus_one, syndromes)
                        .unwrap()
                        .to_vec(),
                )
            }
        };
        // Every gate must produce exactly its declared batch size.
        assert_eq!(res.len(), output.batch_size as usize);
        wire_values.push(res)
    }
    self.iter_output_indices()
        .map(|x| wire_values.get(x).cloned().unwrap())
        .collect()
}
/// Samples random values for every input gate, encoded as `BigUint`s.
///
/// Input gates are visited in `iter_input_indices` order, and each contributes
/// `batch_size` consecutive values. This is the layout `mock_eval_big_uint` consumes.
/// Plaintext, secret-plaintext and share inputs are all sampled the same way.
///
/// # Panics
///
/// Panics if `iter_input_indices` yields a gate that is not an input gate.
pub fn generate_random_inputs<R: MockRng>(&self, mut rng: R) -> Vec<BigUint> {
    let mut sampled = Vec::new();
    for idx in self.iter_input_indices() {
        let GateExt { gate, output, .. } = self.gate_ext_unchecked(*idx);
        let Gate::Input(input_type) = gate else {
            unreachable!("Gate must be an input gate")
        };
        let (Input::SecretPlaintext { algebraic_type, .. }
        | Input::Share { algebraic_type, .. }
        | Input::Plaintext { algebraic_type, .. }) = input_type;
        let batch_size = output.get_batch_size();
        match algebraic_type {
            AlgebraicType::BaseField => {
                sampled.extend((0..batch_size).map(|_| rng.gen_base::<C>().to_biguint()))
            }
            AlgebraicType::ScalarField => {
                sampled.extend((0..batch_size).map(|_| rng.gen_scalar::<C>().to_biguint()))
            }
            AlgebraicType::Point => sampled
                .extend((0..batch_size).map(|_| point_to_big_uint(rng.gen_point::<C>()))),
            AlgebraicType::Bit => {
                sampled.extend((0..batch_size).map(|_| bit_to_big_uint(rng.gen_bit())))
            }
            AlgebraicType::Mersenne107 => {
                sampled.extend((0..batch_size).map(|_| rng.gen_mersenne().to_biguint()))
            }
        }
    }
    sampled
}
/// Mock-evaluates the circuit on inputs and outputs flattened to `BigUint`s.
///
/// `inputs` uses the layout produced by `generate_random_inputs`: input gates in
/// `iter_input_indices` order, each taking `batch_size` consecutive values. Each chunk is
/// decoded into the variant its gate declares and then passed to `mock_eval`. The outputs
/// are flattened back with `RunTimeVal::into_big_uint`.
///
/// # Panics
///
/// Panics if an indexed gate is not an input, if `inputs` runs out before a gate's batch is
/// filled, if a value cannot be decoded, or if `mock_eval` itself panics.
pub fn mock_eval_big_uint<R: MockRng + ?Sized>(
    &self,
    inputs: Vec<BigUint>,
    rng: &mut R,
) -> Vec<BigUint> {
    let mut remaining = inputs.into_iter();
    let mut typed_inputs: Vec<RunTimeVal<C>> = Vec::new();
    for index in self.iter_input_indices() {
        let GateExt { gate, output, .. } = self.gate_ext_unchecked(*index);
        let Gate::Input(input_type) = gate else {
            panic!("Input gate {gate:?} is not actually an input.");
        };
        let width = output.batch_size as usize;
        let chunk = remaining.by_ref().take(width);
        let value = match input_type {
            Input::Plaintext { algebraic_type, .. } => match algebraic_type {
                AlgebraicType::ScalarField => RunTimeVal::PlaintextScalar(
                    chunk.map(big_uint_to_field::<ScalarField<C>>).collect(),
                ),
                AlgebraicType::BaseField => RunTimeVal::PlaintextBase(
                    chunk.map(big_uint_to_field::<BaseField<C>>).collect(),
                ),
                AlgebraicType::Mersenne107 => RunTimeVal::PlaintextMersenne107(
                    chunk.map(big_uint_to_field::<Mersenne107>).collect(),
                ),
                AlgebraicType::Bit => {
                    RunTimeVal::PlaintextBit(chunk.map(big_uint_to_bit).collect())
                }
                AlgebraicType::Point => {
                    RunTimeVal::PlaintextPoint(chunk.map(big_uint_to_point::<C>).collect())
                }
            },
            Input::SecretPlaintext { algebraic_type, .. }
            | Input::Share { algebraic_type, .. } => match algebraic_type {
                AlgebraicType::ScalarField => RunTimeVal::SecretSharedScalar(
                    chunk.map(big_uint_to_field::<ScalarField<C>>).collect(),
                ),
                AlgebraicType::BaseField => RunTimeVal::SecretSharedBase(
                    chunk.map(big_uint_to_field::<BaseField<C>>).collect(),
                ),
                AlgebraicType::Mersenne107 => RunTimeVal::SecretSharedMersenne107(
                    chunk.map(big_uint_to_field::<Mersenne107>).collect(),
                ),
                AlgebraicType::Bit => {
                    RunTimeVal::SecretSharedBit(chunk.map(big_uint_to_bit).collect())
                }
                AlgebraicType::Point => {
                    RunTimeVal::SecretSharedPoint(chunk.map(big_uint_to_point::<C>).collect())
                }
            },
        };
        // A short chunk means `inputs` ran out before this gate's batch was filled.
        assert_eq!(width, value.len());
        typed_inputs.push(value);
    }
    self.mock_eval(&typed_inputs, rng)
        .into_iter()
        .flat_map(|value| value.into_big_uint())
        .collect()
}
}
/// Source of randomness for mock evaluation and for random input generation.
pub trait MockRng {
    /// Samples a bit (used for `Bit` inputs and bit-typed `Random` gates).
    fn gen_bit(&mut self) -> bool;
    /// Samples the bit behind a `DaBit` gate.
    fn gen_da_bit(&mut self) -> bool;
    /// Samples an element of the scalar field of `C`.
    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C>;
    /// Samples an element of the base field of `C`.
    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C>;
    /// Samples an element of the Mersenne107 field.
    fn gen_mersenne(&mut self) -> Mersenne107Element;
    /// Samples a point on `C`.
    fn gen_point<C: Curve>(&mut self) -> Point<C>;
}

/// Every `rand::Rng` is a `MockRng`: each sample is drawn from the generator directly.
impl<R: Rng + ?Sized> MockRng for R {
    fn gen_bit(&mut self) -> bool {
        self.gen::<bool>()
    }

    fn gen_da_bit(&mut self) -> bool {
        self.gen::<bool>()
    }

    fn gen_scalar<C: Curve>(&mut self) -> Scalar<C> {
        Scalar::<C>::random(self)
    }

    fn gen_base<C: Curve>(&mut self) -> BaseFieldElement<C> {
        BaseFieldElement::<C>::random(self)
    }

    fn gen_mersenne(&mut self) -> Mersenne107Element {
        Mersenne107Element::random(self)
    }

    fn gen_point<C: Curve>(&mut self) -> Point<C> {
        self.gen::<Point<C>>()
    }
}
#[cfg(test)]
mod tests {
    use primitives::{algebra::elliptic_curve::Curve25519Ristretto, random::Random};

    use super::*;

    /// Encoding a point through `RunTimeVal::into_big_uint` and decoding it with
    /// `big_uint_to_point` must round-trip.
    #[test]
    fn point_serialization() {
        let mut rng = rand::thread_rng();
        for _ in 0..4 {
            let original = Point::<Curve25519Ristretto>::random(&mut rng);
            let mut encoded = RunTimeVal::PlaintextPoint(vec![original]).into_big_uint();
            let as_integer = encoded.pop().unwrap();
            let decoded = big_uint_to_point(as_integer);
            assert_eq!(decoded, original);
        }
    }
}