use std::iter::repeat_n;
use ff::Field;
use num_traits::{One, Zero};
use primitives::{
algebra::{
elliptic_curve::{BaseFieldElement, Curve, Point, Scalar},
field::{subfield_element::Mersenne107Element, Bit, FieldExtension, SubfieldElement},
BoxedUint,
},
types::PeerNumber,
};
use serde::{Deserialize, Serialize};
use typenum::Unsigned;
use wincode::{SchemaRead, SchemaWrite};
use crate::{
circuit::{errors::BatchSizeError, AlgebraicType, BatchSize, GateIndex, ShareOrPlaintext},
errors::{AbortError, FaultyPeer},
};
/// Unary operations that can be applied to a plaintext field element.
///
/// Each variant is evaluated by [`FieldPlaintextUnaryOp::eval`]. The derived
/// `PartialOrd`/`Ord` compare variants in declaration order, so reordering
/// variants changes comparison results.
#[derive(
Debug,
Clone,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum FieldPlaintextUnaryOp {
/// Additive negation (`-x`).
Neg,
/// Multiplicative inverse; `eval` yields zero when `invert` reports no inverse.
MulInverse,
/// Extracts a single bit of the element's integer representation.
BitExtract {
/// Index of the extracted bit, where index 0 is the least significant bit.
little_endian_bit_idx: u16,
/// When `true`, `eval` treats an element that is greater than its own
/// negation as negative and returns the matching two's-complement bit.
signed: bool,
},
/// Square root; `eval` aborts when the input has no square root.
Sqrt,
/// Exponentiation by a fixed exponent.
Pow {
/// Exponent passed to `pow` by `eval`.
exp: BoxedUint,
},
}
impl FieldPlaintextUnaryOp {
    /// Applies this operation to the plaintext element `x`.
    ///
    /// `label` identifies the gate being evaluated; it is only used to tag
    /// the error produced by `Sqrt`.
    ///
    /// # Errors
    ///
    /// `Sqrt` returns a quadratic-non-residue [`AbortError`] attributed to
    /// [`FaultyPeer::Local`] when `sqrt_ratio` reports failure. All other
    /// variants always succeed.
    pub fn eval<F: FieldExtension>(
        &self,
        label: GateIndex,
        x: &SubfieldElement<F>,
    ) -> Result<SubfieldElement<F>, AbortError> {
        match self {
            Self::Neg => Ok(-x),
            Self::MulInverse => {
                // A missing inverse is mapped to zero instead of being an error.
                let fallback = SubfieldElement::<F>::zero();
                Ok(x.invert().unwrap_or(fallback))
            }
            Self::BitExtract {
                little_endian_bit_idx,
                signed,
            } => {
                let position = u64::from(*little_endian_bit_idx);
                // In signed mode an element larger than its negation is read
                // as a negative number.
                let reads_as_negative = *signed && *x > -x;
                let bit = if reads_as_negative {
                    // Two's complement: the requested bit is the complement
                    // of the same bit of `-1 - x`.
                    let complement = -SubfieldElement::<F>::one() - x;
                    !complement.to_biguint().bit(position)
                } else {
                    x.to_biguint().bit(position)
                };
                Ok(SubfieldElement::<F>::from(bit))
            }
            Self::Sqrt => {
                let one = SubfieldElement::<F>::one();
                let (is_square, root) = SubfieldElement::<F>::sqrt_ratio(x, &one);
                if bool::from(is_square) {
                    Ok(root)
                } else {
                    Err(AbortError::quadratic_non_residue(label, FaultyPeer::Local))
                }
            }
            Self::Pow { exp } => Ok(x.pow(exp)),
        }
    }
}
/// Binary operations that can be applied to two plaintext field elements.
///
/// Each variant is evaluated by [`FieldPlaintextBinaryOp::eval`]. The derived
/// `PartialOrd`/`Ord` compare variants in declaration order.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum FieldPlaintextBinaryOp {
/// Field addition (`x + y`).
Add,
/// Field multiplication (`x * y`).
Mul,
/// Integer quotient of the operands' integer representations; fails on a zero divisor.
EuclDiv,
/// Integer remainder of the operands' integer representations; fails on a zero divisor.
Mod,
/// `x > y`, with the boolean result converted into a field element.
Gt,
/// `x >= y`, with the boolean result converted into a field element.
Ge,
/// `x == y`, with the boolean result converted into a field element.
Eq,
/// Computes `x + y - 2xy`, which is boolean XOR when both operands are 0 or 1.
Xor,
/// Computes `x + y - xy`, which is boolean OR when both operands are 0 or 1.
Or,
}
impl FieldPlaintextBinaryOp {
    /// Applies this operation to the plaintext elements `x` and `y`.
    ///
    /// `label` identifies the gate being evaluated; it is forwarded to the
    /// division helpers so they can tag their errors.
    ///
    /// # Errors
    ///
    /// `EuclDiv` and `Mod` propagate whatever `euclidean_division` and
    /// `modulo` return (a division-by-zero abort when `y` is zero). All other
    /// variants always succeed.
    pub fn eval<F: FieldExtension>(
        &self,
        x: &SubfieldElement<F>,
        y: &SubfieldElement<F>,
        label: GateIndex,
    ) -> Result<SubfieldElement<F>, AbortError> {
        let value = match self {
            // The two fallible operations return directly.
            Self::EuclDiv => return euclidean_division::<F>(x, y, label),
            Self::Mod => return modulo::<F>(x, y, label),
            Self::Add => x + y,
            Self::Mul => x * y,
            // Comparisons produce a boolean that is lifted into the field.
            Self::Gt => SubfieldElement::<F>::from(x > y),
            Self::Ge => SubfieldElement::<F>::from(x >= y),
            Self::Eq => SubfieldElement::<F>::from(x == y),
            // Arithmetic encodings of the boolean connectives on 0/1 inputs.
            Self::Xor => {
                let two = SubfieldElement::<F>::from(2u32);
                x + y - two * x * y
            }
            Self::Or => x + y - x * y,
        };
        Ok(value)
    }
}
/// Divides the integer representation of `x` by that of `y`, discarding the
/// remainder, and returns the quotient as a field element.
///
/// # Errors
///
/// Returns a division-by-zero [`AbortError`] tagged with `label` and
/// [`FaultyPeer::Local`] when `y` is zero, and propagates any error from
/// `from_be_bytes`.
pub(crate) fn euclidean_division<F: FieldExtension>(
    x: &SubfieldElement<F>,
    y: &SubfieldElement<F>,
    label: GateIndex,
) -> Result<SubfieldElement<F>, AbortError> {
    if *y == SubfieldElement::<F>::zero() {
        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
    }
    let quotient_bytes = (x.to_biguint() / y.to_biguint()).to_bytes_be();
    // `to_bytes_be` is minimal-length, so left-pad with zeros up to the
    // field's byte size before decoding.
    let zero_padding = repeat_n(0, F::FieldBytesSize::USIZE - quotient_bytes.len());
    let padded = zero_padding.chain(quotient_bytes).collect::<Vec<_>>();
    let quotient = SubfieldElement::<F>::from_be_bytes(&padded)?;
    Ok(quotient)
}
/// Reduces the integer representation of `x` modulo that of `y` and returns
/// the remainder as a field element.
///
/// # Errors
///
/// Returns a division-by-zero [`AbortError`] tagged with `label` and
/// [`FaultyPeer::Local`] when `y` is zero, and propagates any error from
/// `from_be_bytes`.
fn modulo<F: FieldExtension>(
    x: &SubfieldElement<F>,
    y: &SubfieldElement<F>,
    label: GateIndex,
) -> Result<SubfieldElement<F>, AbortError> {
    if *y == SubfieldElement::<F>::zero() {
        return Err(AbortError::division_by_zero(label, FaultyPeer::Local));
    }
    let x = x.to_biguint();
    let y = y.to_biguint();
    // A plain remainder: `x.modpow(1, y)` computes the same value through a
    // needless modular exponentiation.
    let modulo = (x % y).to_bytes_be();
    // `to_bytes_be` is minimal-length, so left-pad with zeros up to the
    // field's byte size before decoding.
    let modulo = repeat_n(0, F::FieldBytesSize::USIZE - modulo.len())
        .chain(modulo)
        .collect::<Vec<_>>();
    Ok(SubfieldElement::<F>::from_be_bytes(&modulo)?)
}
/// Unary operations on a secret-shared field element.
///
/// NOTE(review): no evaluator for this enum is defined in this file; the
/// per-variant descriptions are inferred from the variant names — confirm
/// against the protocol implementation.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum FieldShareUnaryOp {
/// Presumably additive negation of the shared value.
Neg,
/// Presumably the multiplicative inverse of the shared value.
MulInverse,
/// Presumably reveals (opens) the shared value.
Open,
/// Presumably tests whether the shared value is zero.
IsZero,
}
/// Binary operations on secret-shared field elements.
///
/// NOTE(review): no evaluator for this enum is defined in this file; the
/// per-variant descriptions are inferred from the variant names.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum FieldShareBinaryOp {
/// Presumably addition of the two shared values.
Add,
/// Presumably multiplication of the two shared values.
Mul,
}
/// Unary operations on a secret-shared bit.
///
/// NOTE(review): no evaluator for this enum is defined in this file; the
/// per-variant descriptions are inferred from the variant names.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum BitShareUnaryOp {
/// Presumably logical negation of the shared bit.
Not,
/// Presumably reveals (opens) the shared bit.
Open,
}
/// Binary operations on secret-shared bits.
///
/// NOTE(review): no evaluator for this enum is defined in this file; the
/// per-variant descriptions are inferred from the variant names.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum BitShareBinaryOp {
/// Presumably exclusive OR of the two shared bits.
Xor,
/// Presumably logical OR of the two shared bits.
Or,
/// Presumably logical AND of the two shared bits.
And,
}
/// Unary operations on a plaintext bit, evaluated by [`BitPlaintextUnaryOp::eval`].
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum BitPlaintextUnaryOp {
/// Logical negation, computed by `eval` as `Bit::ONE - x`.
Not,
}
impl BitPlaintextUnaryOp {
    /// Applies this operation to the plaintext bit `x` and returns the result.
    pub fn eval(&self, x: Bit) -> Bit {
        match *self {
            // Negation is expressed arithmetically as one minus the input.
            Self::Not => Bit::ONE - x,
        }
    }
}
/// Binary operations on plaintext bits, evaluated by [`BitPlaintextBinaryOp::eval`].
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum BitPlaintextBinaryOp {
/// Exclusive OR, computed by `eval` as `x + y`.
Xor,
/// Logical OR, computed by `eval` as `x + y - x * y`.
Or,
/// Logical AND, computed by `eval` as `x * y`.
And,
}
impl BitPlaintextBinaryOp {
pub fn eval(&self, x: Bit, y: Bit) -> Bit {
match self {
BitPlaintextBinaryOp::Xor => x + y,
BitPlaintextBinaryOp::Or => x + y - x * y,
BitPlaintextBinaryOp::And => x * y,
}
}
}
/// Unary operations on a plaintext curve point, evaluated by
/// [`PointPlaintextUnaryOp::eval`].
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum PointPlaintextUnaryOp {
/// Point negation (`-x`).
Neg,
}
impl PointPlaintextUnaryOp {
    /// Applies this operation to the plaintext point `x`.
    ///
    /// The only current variant cannot fail, so this always returns `Ok`.
    pub fn eval<C: Curve>(&self, x: &Point<C>) -> Result<Point<C>, AbortError> {
        let result = match *self {
            Self::Neg => -x,
        };
        Ok(result)
    }
}
/// Binary operations on plaintext curve points, evaluated by
/// [`PointPlaintextBinaryOp::eval`].
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum PointPlaintextBinaryOp {
/// Point addition (`x + y`).
Add,
/// Scalar multiplication; `eval` does not support this variant and returns
/// an internal error for it.
ScalarMul,
}
impl PointPlaintextBinaryOp {
    /// Applies this operation to the plaintext points `x` and `y`.
    ///
    /// # Errors
    ///
    /// Returns an internal [`AbortError`] for `ScalarMul`, which this
    /// point-by-point evaluator does not support.
    pub fn eval<C: Curve>(&self, x: &Point<C>, y: &Point<C>) -> Result<Point<C>, AbortError> {
        match *self {
            Self::ScalarMul => Err(AbortError::internal_error(
                "PointPlaintextBinaryOp::eval not supported for PointPlaintextBinaryOp::ScalarMul.",
            )),
            Self::Add => Ok(x + y),
        }
    }
}
/// Unary operations on a secret-shared curve point.
///
/// NOTE(review): no evaluator for this enum is defined in this file; the
/// per-variant descriptions are inferred from the variant names.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum PointShareUnaryOp {
/// Presumably negation of the shared point.
Neg,
/// Presumably reveals (opens) the shared point.
Open,
/// Presumably tests whether the shared point is the identity.
IsZero,
}
/// Binary operations involving secret-shared curve points.
///
/// NOTE(review): no evaluator for this enum is defined in this file; the
/// per-variant descriptions are inferred from the variant names.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Hash,
Serialize,
Deserialize,
SchemaRead,
SchemaWrite,
PartialOrd,
Ord,
)]
#[repr(C)]
pub enum PointShareBinaryOp {
/// Presumably addition of two shared points.
Add,
/// Presumably multiplication of a point by a scalar.
ScalarMul,
}
/// Description of a circuit input: its algebraic type, its batch size and
/// whether it is treated as plaintext or as a share.
///
/// [`Input::share_or_plaintext`] classifies `Plaintext` as plaintext and both
/// `SecretPlaintext` and `Share` as shares.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
#[repr(C)]
pub enum Input {
/// An input classified as plaintext.
Plaintext {
/// Algebraic type of the input.
algebraic_type: AlgebraicType,
/// Batch size of the input.
batch_size: BatchSize,
},
/// An input tied to a specific peer and classified as a share.
/// NOTE(review): presumably a plaintext known only to `inputer` — confirm.
SecretPlaintext {
/// Peer associated with this input (presumably the one providing it).
inputer: PeerNumber,
/// Algebraic type of the input.
algebraic_type: AlgebraicType,
/// Batch size of the input.
batch_size: BatchSize,
},
/// An input classified as a share.
Share {
/// Algebraic type of the input.
algebraic_type: AlgebraicType,
/// Batch size of the input.
batch_size: BatchSize,
},
}
impl Input {
    /// Returns the batch size stored in this input, whichever variant it is.
    pub fn batch_size(&self) -> u32 {
        // Every variant carries the field, so one irrefutable pattern suffices.
        let (Self::Plaintext { batch_size, .. }
        | Self::SecretPlaintext { batch_size, .. }
        | Self::Share { batch_size, .. }) = self;
        *batch_size
    }

    /// Returns the algebraic type stored in this input, whichever variant it is.
    pub fn algebraic_type(&self) -> AlgebraicType {
        let (Self::Plaintext { algebraic_type, .. }
        | Self::Share { algebraic_type, .. }
        | Self::SecretPlaintext { algebraic_type, .. }) = self;
        *algebraic_type
    }

    /// Reports whether this input is handled as a share or as plaintext.
    ///
    /// `SecretPlaintext` counts as a share, like `Share`; only `Plaintext`
    /// maps to [`ShareOrPlaintext::Plaintext`].
    pub fn share_or_plaintext(&self) -> ShareOrPlaintext {
        match self {
            Self::Plaintext { .. } => ShareOrPlaintext::Plaintext,
            Self::SecretPlaintext { .. } | Self::Share { .. } => ShareOrPlaintext::Share,
        }
    }
}
/// A constant value over one of the supported algebraic types, either as a
/// single element or as a batch (`Vec`) of elements.
///
/// The serde bounds are spelled out explicitly so that (de)serialization is
/// required of `Scalar<C>` and `Point<C>` rather than of `C` itself.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, SchemaRead, SchemaWrite)]
#[serde(bound(
serialize = "Scalar<C>: Serialize, Point<C>: Serialize",
deserialize = "Scalar<C>: Deserialize<'de>, Point<C>: Deserialize<'de>"
))]
#[repr(C)]
pub enum Constant<C: Curve> {
/// A single scalar-field element.
Scalar(Scalar<C>),
/// A batch of scalar-field elements.
ScalarBatch(Vec<Scalar<C>>),
/// A single base-field element.
BaseField(BaseFieldElement<C>),
/// A batch of base-field elements.
BaseFieldBatch(Vec<BaseFieldElement<C>>),
/// A single `Mersenne107Element`.
Mersenne107(Mersenne107Element),
/// A batch of `Mersenne107Element`s.
Mersenne107Batch(Vec<Mersenne107Element>),
/// A single bit.
Bit(Bit),
/// A batch of bits.
BitBatch(Vec<Bit>),
/// A single curve point.
Point(Point<C>),
/// A batch of curve points.
PointBatch(Vec<Point<C>>),
}
impl<C: Curve> Constant<C> {
    /// Returns the number of elements in this constant: 1 for the
    /// single-element variants, the vector length for the batch variants.
    ///
    /// # Errors
    ///
    /// Returns a [`BatchSizeError`] carrying the offending length when it
    /// does not fit in a `u32`.
    pub fn batch_size(&self) -> Result<u32, BatchSizeError> {
        let len = match self {
            Self::Scalar(_)
            | Self::BaseField(_)
            | Self::Mersenne107(_)
            | Self::Bit(_)
            | Self::Point(_) => 1,
            Self::ScalarBatch(batch) => batch.len(),
            Self::BaseFieldBatch(batch) => batch.len(),
            Self::Mersenne107Batch(batch) => batch.len(),
            Self::BitBatch(batch) => batch.len(),
            Self::PointBatch(batch) => batch.len(),
        };
        u32::try_from(len).map_err(|_| BatchSizeError(len))
    }

    /// Returns the algebraic type of this constant; a batch variant has the
    /// same type as its single-element counterpart.
    pub fn algebraic_type(&self) -> AlgebraicType {
        match self {
            Self::Point(_) | Self::PointBatch(_) => AlgebraicType::Point,
            Self::Bit(_) | Self::BitBatch(_) => AlgebraicType::Bit,
            Self::Mersenne107(_) | Self::Mersenne107Batch(_) => AlgebraicType::Mersenne107,
            Self::BaseField(_) | Self::BaseFieldBatch(_) => AlgebraicType::BaseField,
            Self::Scalar(_) | Self::ScalarBatch(_) => AlgebraicType::ScalarField,
        }
    }
}
#[cfg(test)]
mod tests {
    use primitives::algebra::{
        elliptic_curve::{BaseField, Curve25519Ristretto as C, ScalarField},
        field::SubfieldElement,
    };

    use super::*;

    /// `Neg` and `MulInverse` agree with the element's own negation and inversion.
    #[test]
    fn test_scalar_unary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = 0;
        let neg = FieldPlaintextUnaryOp::Neg;
        let mul_inverse = FieldPlaintextUnaryOp::MulInverse;
        assert_eq!(neg.eval::<ScalarField<C>>(label, &x), Ok(-x));
        assert_eq!(
            mul_inverse.eval::<ScalarField<C>>(label, &x),
            Ok(x.invert().unwrap())
        );
    }

    /// Each arithmetic and comparison variant dispatches to the expected computation.
    #[test]
    fn test_scalar_binary_op() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let y = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = 0;
        let add = FieldPlaintextBinaryOp::Add;
        let mul = FieldPlaintextBinaryOp::Mul;
        let eucl_div = FieldPlaintextBinaryOp::EuclDiv;
        let modulo_op = FieldPlaintextBinaryOp::Mod;
        let gt = FieldPlaintextBinaryOp::Gt;
        let ge = FieldPlaintextBinaryOp::Ge;
        let eq = FieldPlaintextBinaryOp::Eq;
        assert_eq!(add.eval::<ScalarField<C>>(&x, &y, label), Ok(x + y));
        assert_eq!(mul.eval::<ScalarField<C>>(&x, &y, label), Ok(x * y));
        assert_eq!(
            eucl_div.eval::<ScalarField<C>>(&x, &y, label),
            euclidean_division::<ScalarField<C>>(&x, &y, label)
        );
        assert_eq!(
            modulo_op.eval::<ScalarField<C>>(&x, &y, label),
            modulo::<ScalarField<C>>(&x, &y, label)
        );
        assert_eq!(
            gt.eval::<ScalarField<C>>(&x, &y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x > y))
        );
        assert_eq!(
            ge.eval::<ScalarField<C>>(&x, &y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x >= y))
        );
        assert_eq!(
            eq.eval::<ScalarField<C>>(&x, &y, label),
            Ok(SubfieldElement::<ScalarField<C>>::from(x == y))
        );
    }

    /// On 0/1 inputs, `Mul`, `Or` and `Xor` behave as boolean AND, OR and XOR.
    #[test]
    fn test_scalar_boolean_binary_op() {
        let and = FieldPlaintextBinaryOp::Mul;
        let or = FieldPlaintextBinaryOp::Or;
        let xor = FieldPlaintextBinaryOp::Xor;
        let label = 0;
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let scalar_x = SubfieldElement::<ScalarField<C>>::from(bool_x);
                let scalar_y = SubfieldElement::<ScalarField<C>>::from(bool_y);
                assert_eq!(
                    and.eval::<ScalarField<C>>(&scalar_x, &scalar_y, label),
                    Ok((bool_x && bool_y).into())
                );
                assert_eq!(
                    or.eval::<ScalarField<C>>(&scalar_x, &scalar_y, label),
                    Ok((bool_x || bool_y).into())
                );
                assert_eq!(
                    xor.eval::<ScalarField<C>>(&scalar_x, &scalar_y, label),
                    Ok((bool_x ^ bool_y).into())
                );
            }
        }
    }

    /// The plaintext bit operations match Rust's boolean operators on every input.
    #[test]
    fn test_bit_ops() {
        let not = BitPlaintextUnaryOp::Not;
        for bool_x in [false, true] {
            let x = Bit::from(bool_x);
            assert_eq!(not.eval(x), (!bool_x).into());
        }
        let and = BitPlaintextBinaryOp::And;
        let or = BitPlaintextBinaryOp::Or;
        let xor = BitPlaintextBinaryOp::Xor;
        for bool_x in [false, true] {
            for bool_y in [false, true] {
                let x = Bit::from(bool_x);
                let y = Bit::from(bool_y);
                assert_eq!(and.eval(x, y), (bool_x && bool_y).into());
                assert_eq!(or.eval(x, y), (bool_x || bool_y).into());
                assert_eq!(xor.eval(x, y), (bool_x ^ bool_y).into());
            }
        }
    }

    /// `euclidean_division` matches integer division on small values.
    #[test]
    fn test_euclidean_division() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = 0;
        let result = euclidean_division::<ScalarField<C>>(&x, &y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 / 12));
    }

    /// `modulo` matches the integer remainder on small values.
    #[test]
    fn test_modulo() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let y = SubfieldElement::<ScalarField<C>>::from(12u32);
        let label = 0;
        let result = modulo::<ScalarField<C>>(&x, &y, label).unwrap();
        assert_eq!(result, SubfieldElement::<ScalarField<C>>::from(37u32 % 12));
    }

    /// A zero divisor is rejected by both `euclidean_division` and `modulo`.
    #[test]
    fn test_division_by_zero() {
        let x = SubfieldElement::<ScalarField<C>>::from(37u32);
        let zero = SubfieldElement::<ScalarField<C>>::from(0u32);
        let label = 0;
        assert!(euclidean_division::<ScalarField<C>>(&x, &zero, label).is_err());
        assert!(modulo::<ScalarField<C>>(&x, &zero, label).is_err());
    }

    /// Signed bit extraction of `-9` yields the two's-complement bits of `-9i32`.
    #[test]
    fn test_signed_bit_extract() {
        let x = -Scalar::<C>::from(9u32);
        let label = 0;
        for i in 0..5 {
            let op = FieldPlaintextUnaryOp::BitExtract {
                little_endian_bit_idx: i,
                signed: true,
            };
            let result = op.eval::<ScalarField<C>>(label, &x);
            assert_eq!(result.unwrap(), ((-9i32 >> i) & 1 == 1).into())
        }
    }

    /// The square root of a square, squared again, gives back the square.
    #[test]
    fn test_sqrt() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<ScalarField<C>>::random(&mut rng);
        let label = 0;
        let result = FieldPlaintextUnaryOp::Sqrt
            .eval::<ScalarField<C>>(label, &(x * x))
            .unwrap();
        assert_eq!(result * result, x * x)
    }

    /// Raising to the 5th power and then to `five_inv` round-trips the input.
    #[test]
    fn test_pow() {
        let mut rng = rand::thread_rng();
        let x = SubfieldElement::<BaseField<C>>::random(&mut rng);
        let label = 0;
        let five = BoxedUint::from(vec![5u64]);
        // Exponent chosen so that `(x^5)^five_inv == x`, as asserted below.
        let five_inv = BoxedUint::from(vec![
            14757395258967641281,
            14757395258967641292,
            14757395258967641292,
            5534023222112865484,
        ]);
        let x_pow_5 = FieldPlaintextUnaryOp::Pow { exp: five }
            .eval::<BaseField<C>>(label, &x)
            .unwrap();
        let x_again = FieldPlaintextUnaryOp::Pow { exp: five_inv }
            .eval::<BaseField<C>>(label, &x_pow_5)
            .unwrap();
        assert_eq!(x_again, x)
    }
}