use std::{
fmt::Debug,
iter::Sum,
ops::{Add, Mul, Neg, Sub},
pin::Pin,
task::{Context, Poll},
};
use ark_ec::CurveGroup;
use futures::{Future, FutureExt};
use itertools::{izip, Itertools};
use crate::{
commitment::{PedersenCommitment, PedersenCommitmentResult},
error::MpcError,
fabric::{MpcFabric, ResultId, ResultValue},
PARTY0,
};
use super::{
authenticated_curve::AuthenticatedPointResult,
curve::{CurvePoint, CurvePointResult},
macros::{impl_borrow_variants, impl_commutative},
mpc_scalar::MpcScalarResult,
scalar::{BatchScalarResult, Scalar, ScalarResult},
};
/// The number of `ScalarResult`s that make up one `AuthenticatedScalarResult`
/// when it is flattened: the share, the MAC share and the public modifier, in
/// that order (the layout returned by `ids` and consumed by
/// `from_flattened_iterator`).
pub const AUTHENTICATED_SCALAR_RESULT_LEN: usize = 3;
/// An authenticated share of a secret scalar.
///
/// Alongside this party's share of the value, the struct carries this party's
/// share of a MAC on the value and a public modifier. Opening with
/// `open_authenticated` checks `mac_key * (value + public_modifier)` against
/// the MAC shares.
#[derive(Clone)]
pub struct AuthenticatedScalarResult<C: CurveGroup> {
    /// This party's share of the underlying value
    pub(crate) share: MpcScalarResult<C>,
    /// This party's share of the MAC on the value
    pub(crate) mac: MpcScalarResult<C>,
    /// A public offset that absorbs public constants added to or subtracted
    /// from the value, so the MAC does not need to be recomputed for them
    pub(crate) public_modifier: ScalarResult<C>,
}
impl<C: CurveGroup> Debug for AuthenticatedScalarResult<C> {
    /// Print the result ids of the three components rather than their values,
    /// which may not be available yet.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AuthenticatedScalarResult<C>");
        builder.field("value", &self.share.id());
        builder.field("mac", &self.mac.id());
        builder.field("public_modifier", &self.public_modifier.id);
        builder.finish()
    }
}
impl<C: CurveGroup> AuthenticatedScalarResult<C> {
    /// Wrap a share of a value as an authenticated share.
    ///
    /// The MAC share is the (MPC) product of the fabric's MAC key share and the
    /// value's share. The public modifier starts at zero.
    pub fn new_shared(value: ScalarResult<C>) -> Self {
        let fabric = value.fabric.clone();
        let mpc_value = MpcScalarResult::new_shared(value);
        let mac = fabric.borrow_mac_key() * mpc_value.clone();
        let public_modifier = fabric.zero();
        Self {
            share: mpc_value,
            mac,
            public_modifier,
        }
    }

    /// Batched version of [`Self::new_shared`].
    ///
    /// The MACs of all values are computed with a single batched
    /// multiplication against the MAC key. Returns an empty vector when
    /// `values` is empty.
    pub fn new_shared_batch(values: &[ScalarResult<C>]) -> Vec<Self> {
        if values.is_empty() {
            return vec![];
        }

        let n = values.len();
        let fabric = values[0].fabric();
        let mpc_values = values
            .iter()
            .map(|v| MpcScalarResult::new_shared(v.clone()))
            .collect_vec();

        // One copy of the MAC key per value so the multiplication can be batched
        let mac_keys = (0..n)
            .map(|_| fabric.borrow_mac_key().clone())
            .collect_vec();
        let values_macs = MpcScalarResult::batch_mul(&mpc_values, &mac_keys);

        mpc_values
            .into_iter()
            .zip(values_macs.into_iter())
            .map(|(value, mac)| Self {
                share: value,
                mac,
                public_modifier: fabric.zero(),
            })
            .collect_vec()
    }

    /// Split a batch result holding `n` scalars into individual results and
    /// wrap each one as an authenticated share (see [`Self::new_shared_batch`]).
    pub fn new_shared_from_batch_result(
        values: BatchScalarResult<C>,
        n: usize,
    ) -> Vec<AuthenticatedScalarResult<C>> {
        let scalar_results: Vec<ScalarResult<C>> =
            values
                .fabric()
                .new_batch_gate_op(vec![values.id()], n, |mut args| {
                    let scalars: Vec<Scalar<C>> = args.pop().unwrap().into();
                    scalars.into_iter().map(ResultValue::Scalar).collect()
                });
        Self::new_shared_batch(&scalar_results)
    }

    /// Get a copy of the underlying MPC share (test helper)
    #[cfg(feature = "test_helpers")]
    pub fn mpc_share(&self) -> MpcScalarResult<C> {
        self.share.clone()
    }

    /// Get this party's raw share of the value
    pub fn share(&self) -> ScalarResult<C> {
        self.share.to_scalar()
    }

    /// Get a reference to the fabric the value is allocated in
    pub fn fabric(&self) -> &MpcFabric<C> {
        self.share.fabric()
    }

    /// The result ids of this value in flattened order:
    /// `[share, mac, public_modifier]` (see [`AUTHENTICATED_SCALAR_RESULT_LEN`])
    pub fn ids(&self) -> Vec<ResultId> {
        vec![self.share.id(), self.mac.id(), self.public_modifier.id]
    }

    /// Open the value *without* checking its MAC
    pub fn open(&self) -> ScalarResult<C> {
        self.share.open()
    }

    /// Open a batch of values *without* checking their MACs
    pub fn open_batch(values: &[Self]) -> Vec<ScalarResult<C>> {
        MpcScalarResult::open_batch(&values.iter().map(|val| val.share.clone()).collect_vec())
    }

    /// Rebuild values from a flat iterator of results.
    ///
    /// The results must be laid out as `[share, mac, public_modifier]` repeated,
    /// i.e. the layout produced by [`Self::ids`].
    ///
    /// # Panics
    /// If the iterator's length is not a multiple of
    /// [`AUTHENTICATED_SCALAR_RESULT_LEN`].
    pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
    where
        I: Iterator<Item = ScalarResult<C>>,
    {
        iter.chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
            .into_iter()
            .map(|mut chunk| Self {
                share: chunk.next().unwrap().into(),
                mac: chunk.next().unwrap().into(),
                public_modifier: chunk.next().unwrap(),
            })
            .collect_vec()
    }

    /// Verify the MAC-check value the peer revealed.
    ///
    /// Returns `true` iff the peer's Pedersen commitment opens to
    /// `peer_mac_share` under `peer_commitment_blinder`, and the two parties'
    /// MAC-check shares sum to zero.
    pub fn verify_mac_check(
        my_mac_share: Scalar<C>,
        peer_mac_share: Scalar<C>,
        peer_mac_commitment: CurvePoint<C>,
        peer_commitment_blinder: Scalar<C>,
    ) -> bool {
        let their_comm = PedersenCommitment {
            value: peer_mac_share,
            blinder: peer_commitment_blinder,
            commitment: peer_mac_commitment,
        };

        // The peer must have revealed exactly what it committed to
        if !their_comm.verify() {
            return false;
        }

        // The MAC-check shares of an honestly computed value sum to zero
        if peer_mac_share + my_mac_share != Scalar::zero() {
            return false;
        }

        true
    }

    /// Open the value and run a commit-and-reveal MAC check with the peer.
    ///
    /// Each party does the following:
    /// - computes `mac_key_share * (value + public_modifier) - mac_share`;
    /// - commits to it with a Pedersen commitment;
    /// - exchanges the commitment, the committed value and the blinder.
    ///
    /// The result's `mac_check` holds [`Self::verify_mac_check`] applied to the
    /// exchanged values.
    pub fn open_authenticated(&self) -> AuthenticatedScalarOpenResult<C> {
        let recovered_value = self.share.open();

        // This party's share of the MAC check value
        let mac_check_value: ScalarResult<C> = self.fabric().new_gate_op(
            vec![
                self.fabric().borrow_mac_key().id(),
                recovered_value.id,
                self.public_modifier.id,
                self.mac.id(),
            ],
            move |mut args| {
                let mac_key_share: Scalar<C> = args.remove(0).into();
                let value: Scalar<C> = args.remove(0).into();
                let modifier: Scalar<C> = args.remove(0).into();
                let mac_share: Scalar<C> = args.remove(0).into();

                ResultValue::Scalar(mac_key_share * (value + modifier) - mac_share)
            },
        );

        // Commit to the check value and exchange commitment, value and blinder
        let my_comm = PedersenCommitmentResult::commit(mac_check_value);
        let peer_commit = self.fabric().exchange_value(my_comm.commitment);
        let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());

        let blinder_result: ScalarResult<C> = self.fabric().allocate_scalar(my_comm.blinder);
        let peer_blinder = self.fabric().exchange_value(blinder_result);

        // Verify the peer's opening and that the check shares cancel
        let commitment_check: ScalarResult<C> = self.fabric().new_gate_op(
            vec![
                my_comm.value.id,
                peer_mac_check.id,
                peer_blinder.id,
                peer_commit.id,
            ],
            |mut args| {
                let my_comm_value: Scalar<C> = args.remove(0).into();
                let peer_value: Scalar<C> = args.remove(0).into();
                let blinder: Scalar<C> = args.remove(0).into();
                let commitment: CurvePoint<C> = args.remove(0).into();

                ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
                    my_comm_value,
                    peer_value,
                    commitment,
                    blinder,
                )))
            },
        );

        AuthenticatedScalarOpenResult {
            value: recovered_value,
            mac_check: commitment_check,
        }
    }

    /// Batched version of [`Self::open_authenticated`].
    ///
    /// The MAC-check values for all inputs are computed in one batched gate,
    /// and the commitment checks in another. Returns an empty vector when
    /// `values` is empty.
    pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedScalarOpenResult<C>> {
        if values.is_empty() {
            return vec![];
        }

        let n = values.len();
        let fabric = values[0].fabric();

        // Open the values themselves
        let values_open = Self::open_batch(values);

        // Gate inputs are laid out as
        // `[mac_key, value_0, modifier_0, mac_0, ..., value_{n-1}, modifier_{n-1}, mac_{n-1}]`
        let mut mac_check_deps = Vec::with_capacity(1 + 3 * n);
        mac_check_deps.push(fabric.borrow_mac_key().id());
        for (opened, value) in values_open.iter().zip(values.iter()) {
            mac_check_deps.push(opened.id());
            mac_check_deps.push(value.public_modifier.id());
            mac_check_deps.push(value.mac.id());
        }

        let mac_checks: Vec<ScalarResult<C>> =
            fabric.new_batch_gate_op(mac_check_deps, n, move |args| {
                // Consume the arguments front-to-back with an iterator. Repeated
                // `Vec::remove(0)` would shift the whole buffer each time (O(n^2)).
                let mut args = args.into_iter();
                let mac_key_share: Scalar<C> = args.next().unwrap().into();

                let mut check_result = Vec::with_capacity(n);
                for _ in 0..n {
                    let value: Scalar<C> = args.next().unwrap().into();
                    let modifier: Scalar<C> = args.next().unwrap().into();
                    let mac_share: Scalar<C> = args.next().unwrap().into();

                    check_result.push(ResultValue::Scalar(
                        mac_key_share * (value + modifier) - mac_share,
                    ));
                }

                check_result
            });

        // Commit to each check value and exchange commitments, values and blinders
        let my_comms = mac_checks
            .iter()
            .cloned()
            .map(PedersenCommitmentResult::commit)
            .collect_vec();
        let peer_comms = fabric.exchange_values(
            &my_comms
                .iter()
                .map(|comm| comm.commitment.clone())
                .collect_vec(),
        );

        let peer_mac_checks = fabric.exchange_values(&mac_checks);
        let peer_blinders = fabric.exchange_values(
            &my_comms
                .iter()
                .map(|comm| fabric.allocate_scalar(comm.blinder))
                .collect_vec(),
        );

        // Gate inputs: `n` local check values followed by the three exchanged batches
        let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
        mac_check_gate_deps.push(peer_mac_checks.id);
        mac_check_gate_deps.push(peer_blinders.id);
        mac_check_gate_deps.push(peer_comms.id);

        let commitment_checks: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
            mac_check_gate_deps,
            n,
            move |mut args| {
                let my_comms: Vec<Scalar<C>> = args.drain(..n).map(|comm| comm.into()).collect();
                let peer_mac_checks: Vec<Scalar<C>> = args.remove(0).into();
                let peer_blinders: Vec<Scalar<C>> = args.remove(0).into();
                let peer_comms: Vec<CurvePoint<C>> = args.remove(0).into();

                let mut mac_checks = Vec::with_capacity(n);
                for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
                    my_comms.into_iter(),
                    peer_mac_checks.into_iter(),
                    peer_blinders.into_iter(),
                    peer_comms.into_iter()
                ) {
                    let mac_check = Self::verify_mac_check(
                        my_mac_share,
                        peer_mac_share,
                        peer_commitment,
                        peer_blinder,
                    );
                    mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
                }

                mac_checks
            },
        );

        values_open
            .into_iter()
            .zip(commitment_checks.into_iter())
            .map(|(value, check)| AuthenticatedScalarOpenResult {
                value,
                mac_check: check,
            })
            .collect_vec()
    }
}
/// The result of an authenticated opening.
///
/// Holds the opened value together with the outcome of the MAC check.
/// Awaiting it yields the value only if the MAC check passed.
#[derive(Clone)]
pub struct AuthenticatedScalarOpenResult<C: CurveGroup> {
    /// The opened (reconstructed) value
    pub value: ScalarResult<C>,
    /// The MAC check outcome as a scalar-encoded boolean; the `Future` impl
    /// treats the value one as success
    pub mac_check: ScalarResult<C>,
}
impl<C: CurveGroup> Future for AuthenticatedScalarOpenResult<C>
where
    C::ScalarField: Unpin,
{
    type Output = Result<Scalar<C>, MpcError>;

    /// Wait for the opened value, then for the MAC check.
    ///
    /// Resolves to the value when the check result equals one, and to
    /// `MpcError::AuthenticationError` otherwise.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let opened = futures::ready!(self.as_mut().value.poll_unpin(cx));
        let check = futures::ready!(self.as_mut().mac_check.poll_unpin(cx));

        let outcome = if check != Scalar::from(1u8) {
            Err(MpcError::AuthenticationError)
        } else {
            Ok(opened)
        };
        Poll::Ready(outcome)
    }
}
impl<C: CurveGroup> Add<&Scalar<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Add a public constant.
    ///
    /// Only party 0 adds `rhs` to its share; the other party adds zero
    /// (presumably so that it still allocates a fresh result for the share).
    /// Both parties subtract `rhs` from the public modifier so the MAC stays
    /// valid.
    fn add(self, rhs: &Scalar<C>) -> Self::Output {
        let is_party0 = self.fabric().party_id() == PARTY0;
        let share = if is_party0 {
            &self.share + rhs
        } else {
            &self.share + Scalar::zero()
        };
        let public_modifier = &self.public_modifier - rhs;

        AuthenticatedScalarResult {
            share,
            mac: self.mac.clone(),
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Add, add, +, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl_commutative!(AuthenticatedScalarResult<C>, Add, add, +, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Add<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Add a public (not yet resolved) scalar.
    ///
    /// Only party 0 adds `rhs` to its share; the other party adds zero. Both
    /// parties subtract `rhs` from the public modifier so the MAC stays valid.
    fn add(self, rhs: &ScalarResult<C>) -> Self::Output {
        let is_party0 = self.fabric().party_id() == PARTY0;
        let share = if is_party0 {
            &self.share + rhs
        } else {
            &self.share + Scalar::zero()
        };
        let public_modifier = &self.public_modifier - rhs;

        AuthenticatedScalarResult {
            share,
            mac: self.mac.clone(),
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Add, add, +, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl_commutative!(AuthenticatedScalarResult<C>, Add, add, +, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Add<&AuthenticatedScalarResult<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Add two authenticated values by adding share, MAC share and public
    /// modifier component-wise.
    fn add(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
        let share = &self.share + &rhs.share;
        let mac = &self.mac + &rhs.mac;
        let public_modifier = self.public_modifier.clone() + rhs.public_modifier.clone();

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Add, add, +, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> AuthenticatedScalarResult<C> {
    /// Add two batches of authenticated values element-wise in one batched gate.
    ///
    /// Shares, MAC shares and public modifiers are added component-wise.
    /// Returns an empty vector when both batches are empty.
    ///
    /// # Panics
    /// If `a` and `b` have different lengths.
    pub fn batch_add(
        a: &[AuthenticatedScalarResult<C>],
        b: &[AuthenticatedScalarResult<C>],
    ) -> Vec<AuthenticatedScalarResult<C>> {
        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
        // Guard before indexing `a[0]` below, matching `batch_neg` / `batch_mul`
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();

        // Gate inputs: every flattened component of `a`, then every one of `b`
        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n,
            move |mut args| {
                let arg_len = args.len();
                let a_vals = args.drain(..arg_len / 2).collect_vec();
                let b_vals = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, mut b_vals) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(
                        b_vals
                            .into_iter()
                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                            .into_iter(),
                    )
                {
                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();

                    let b_share: Scalar<C> = b_vals.next().unwrap().into();
                    let b_mac_share: Scalar<C> = b_vals.next().unwrap().into();
                    let b_modifier: Scalar<C> = b_vals.next().unwrap().into();

                    result.push(ResultValue::Scalar(a_share + b_share));
                    result.push(ResultValue::Scalar(a_mac_share + b_mac_share));
                    result.push(ResultValue::Scalar(a_modifier + b_modifier));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }

    /// Add a batch of public values to a batch of authenticated values
    /// element-wise in one batched gate.
    ///
    /// Only party 0 adds the public value to its share. Both parties subtract
    /// it from the public modifier, and MAC shares are left unchanged.
    /// Returns an empty vector when both batches are empty.
    ///
    /// # Panics
    /// If `a` and `b` have different lengths.
    pub fn batch_add_public(
        a: &[AuthenticatedScalarResult<C>],
        b: &[ScalarResult<C>],
    ) -> Vec<AuthenticatedScalarResult<C>> {
        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
        // Guard before indexing `a[0]` below
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();

        // Gate inputs: every flattened component of `a`, then the `n` public values
        let all_ids = a
            .iter()
            .flat_map(|v| v.ids())
            .chain(b.iter().map(|v| v.id()))
            .collect_vec();

        let party_id = fabric.party_id();
        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n,
            move |mut args| {
                let a_vals = args
                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
                    .collect_vec();
                let public_values = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, public_value) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(public_values.into_iter())
                {
                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();
                    let public_value: Scalar<C> = public_value.into();

                    if party_id == PARTY0 {
                        result.push(ResultValue::Scalar(a_share + public_value));
                    } else {
                        result.push(ResultValue::Scalar(a_share));
                    }

                    result.push(ResultValue::Scalar(a_mac_share));
                    result.push(ResultValue::Scalar(a_modifier - public_value));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }
}
impl<C: CurveGroup> Sum for AuthenticatedScalarResult<C> {
    /// Sum the values left to right using authenticated addition.
    ///
    /// # Panics
    /// If the iterator is empty, because there is no fabric from which to
    /// allocate a zero.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut remaining = iter;
        let mut total = remaining.next().expect("Cannot sum empty iterator");
        for term in remaining {
            total = total + &term;
        }
        total
    }
}
impl<C: CurveGroup> Sub<&Scalar<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Subtract a public constant.
    ///
    /// The constant is subtracted through the underlying MPC share, and the
    /// public modifier grows by `rhs` so the unchanged MAC stays valid.
    fn sub(self, rhs: &Scalar<C>) -> Self::Output {
        let share = &self.share - rhs;
        let public_modifier = &self.public_modifier + rhs;

        AuthenticatedScalarResult {
            share,
            mac: self.mac.clone(),
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Sub, sub, -, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Sub<&AuthenticatedScalarResult<C>> for &Scalar<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Compute `constant - value`.
    ///
    /// The MAC share is negated, and the public modifier becomes
    /// `-constant - modifier`.
    fn sub(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
        let share = self - &rhs.share;
        let public_modifier = -self - &rhs.public_modifier;
        let mac = -&rhs.mac;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(Scalar<C>, Sub, sub, -, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Sub<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Subtract a public (not yet resolved) scalar.
    ///
    /// The public modifier grows by `rhs` so the unchanged MAC stays valid.
    fn sub(self, rhs: &ScalarResult<C>) -> Self::Output {
        let share = &self.share - rhs;
        let public_modifier = &self.public_modifier + rhs;

        AuthenticatedScalarResult {
            share,
            mac: self.mac.clone(),
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Sub, sub, -, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Sub<&AuthenticatedScalarResult<C>> for &ScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Compute `public - value` for a public (not yet resolved) scalar.
    ///
    /// The MAC share is negated, and the public modifier becomes
    /// `-public - modifier`.
    fn sub(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
        let share = self - &rhs.share;
        let public_modifier = -self - &rhs.public_modifier;
        let mac = -&rhs.mac;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(ScalarResult<C>, Sub, sub, -, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Sub<&AuthenticatedScalarResult<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Subtract two authenticated values component-wise (share, MAC share and
    /// public modifier).
    fn sub(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
        let share = &self.share - &rhs.share;
        let mac = &self.mac - &rhs.mac;
        let public_modifier = self.public_modifier.clone() - rhs.public_modifier.clone();

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Sub, sub, -, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> AuthenticatedScalarResult<C> {
    /// Subtract two batches of authenticated values element-wise in one batched
    /// gate.
    ///
    /// Shares, MAC shares and public modifiers are subtracted component-wise.
    /// Returns an empty vector when both batches are empty.
    ///
    /// # Panics
    /// If `a` and `b` have different lengths.
    pub fn batch_sub(
        a: &[AuthenticatedScalarResult<C>],
        b: &[AuthenticatedScalarResult<C>],
    ) -> Vec<AuthenticatedScalarResult<C>> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot subtract batches of different sizes"
        );
        // Guard before indexing `a[0]` below, matching `batch_neg` / `batch_mul`
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();

        // Gate inputs: every flattened component of `a`, then every one of `b`
        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();
        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n,
            move |mut args| {
                let arg_len = args.len();
                let a_vals = args.drain(..arg_len / 2).collect_vec();
                let b_vals = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, mut b_vals) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(
                        b_vals
                            .into_iter()
                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                            .into_iter(),
                    )
                {
                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();

                    let b_share: Scalar<C> = b_vals.next().unwrap().into();
                    let b_mac_share: Scalar<C> = b_vals.next().unwrap().into();
                    let b_modifier: Scalar<C> = b_vals.next().unwrap().into();

                    result.push(ResultValue::Scalar(a_share - b_share));
                    result.push(ResultValue::Scalar(a_mac_share - b_mac_share));
                    result.push(ResultValue::Scalar(a_modifier - b_modifier));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }

    /// Subtract a batch of public values from a batch of authenticated values
    /// element-wise in one batched gate.
    ///
    /// Only party 0 subtracts the public value from its share. Both parties add
    /// it to the public modifier, and MAC shares are left unchanged.
    /// Returns an empty vector when both batches are empty.
    ///
    /// # Panics
    /// If `a` and `b` have different lengths.
    pub fn batch_sub_public(
        a: &[AuthenticatedScalarResult<C>],
        b: &[ScalarResult<C>],
    ) -> Vec<AuthenticatedScalarResult<C>> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot subtract batches of different sizes"
        );
        // Guard before indexing `a[0]` below
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();

        // Gate inputs: every flattened component of `a`, then the `n` public values
        let all_ids = a
            .iter()
            .flat_map(|v| v.ids())
            .chain(b.iter().map(|v| v.id()))
            .collect_vec();

        let party_id = fabric.party_id();
        let gate_results: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n,
            move |mut args| {
                let a_vals = args
                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
                    .collect_vec();
                let public_values = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, public_value) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(public_values.into_iter())
                {
                    let a_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar<C> = a_vals.next().unwrap().into();
                    let a_modifier: Scalar<C> = a_vals.next().unwrap().into();
                    let public_value: Scalar<C> = public_value.into();

                    if party_id == PARTY0 {
                        result.push(ResultValue::Scalar(a_share - public_value));
                    } else {
                        result.push(ResultValue::Scalar(a_share));
                    }

                    result.push(ResultValue::Scalar(a_mac_share));
                    result.push(ResultValue::Scalar(a_modifier + public_value));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }
}
impl<C: CurveGroup> Neg for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Negate every component: share, MAC share and public modifier.
    fn neg(self) -> Self::Output {
        let share = -&self.share;
        let mac = -&self.mac;
        let public_modifier = -&self.public_modifier;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Neg, neg, -, C: CurveGroup);
impl<C: CurveGroup> AuthenticatedScalarResult<C> {
pub fn batch_neg(a: &[AuthenticatedScalarResult<C>]) -> Vec<AuthenticatedScalarResult<C>> {
if a.is_empty() {
return vec![];
}
let n = a.len();
let fabric = a[0].fabric();
let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
let scalars = fabric.new_batch_gate_op(
all_ids,
AUTHENTICATED_SCALAR_RESULT_LEN * n, |args| {
args.into_iter()
.map(|arg| ResultValue::Scalar(-Scalar::from(arg)))
.collect()
},
);
AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
}
}
impl<C: CurveGroup> Mul<&Scalar<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Scale by a public constant by multiplying the share, MAC share and
    /// public modifier by `rhs`.
    fn mul(self, rhs: &Scalar<C>) -> Self::Output {
        let share = &self.share * rhs;
        let mac = &self.mac * rhs;
        let public_modifier = &self.public_modifier * rhs;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Mul, mul, *, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl_commutative!(AuthenticatedScalarResult<C>, Mul, mul, *, Scalar<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Mul<&ScalarResult<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Scale by a public (not yet resolved) scalar by multiplying the share,
    /// MAC share and public modifier by `rhs`.
    fn mul(self, rhs: &ScalarResult<C>) -> Self::Output {
        let share = &self.share * rhs;
        let mac = &self.mac * rhs;
        let public_modifier = &self.public_modifier * rhs;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Mul, mul, *, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl_commutative!(AuthenticatedScalarResult<C>, Mul, mul, *, ScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &AuthenticatedScalarResult<C> {
    type Output = AuthenticatedScalarResult<C>;

    /// Multiply two authenticated values with an authenticated Beaver triple
    /// `(a, b, c)` from the fabric, assuming `c = a * b`.
    ///
    /// The method opens `d = self - a` and `e = rhs - b` without MAC checks,
    /// then returns `d*e + d*b + e*a + c`.
    fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
        let (triple_a, triple_b, triple_c) = self.fabric().next_authenticated_triple();

        let lhs_masked = self - &triple_a;
        let rhs_masked = rhs - &triple_b;
        let d = lhs_masked.open();
        let e = rhs_masked.open();

        let de = &d * &e;
        let db = d * triple_b;
        let partial = de + db;
        let ea = e * triple_a;
        let partial = partial + ea;
        partial + triple_c
    }
}
impl_borrow_variants!(AuthenticatedScalarResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedScalarResult<C>, C: CurveGroup);
impl<C: CurveGroup> AuthenticatedScalarResult<C> {
    /// Multiply two batches of authenticated values element-wise using a batch
    /// of authenticated Beaver triples.
    ///
    /// Returns an empty vector when both batches are empty.
    ///
    /// # Panics
    /// If `a` and `b` have different lengths.
    pub fn batch_mul(
        a: &[AuthenticatedScalarResult<C>],
        b: &[AuthenticatedScalarResult<C>],
    ) -> Vec<AuthenticatedScalarResult<C>> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot multiply batches of different sizes"
        );
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let (triple_a, triple_b, triple_c) = a[0].fabric().next_authenticated_triple_batch(n);

        // Mask both operands and open all masked values in a single batch
        let lhs_masked = Self::batch_sub(a, &triple_a);
        let rhs_masked = Self::batch_sub(b, &triple_b);
        let opened = Self::open_batch(&[lhs_masked, rhs_masked].concat());
        let (d, e) = opened.split_at(n);

        // Recombine as d*e + d*b + e*a + c
        let de = ScalarResult::batch_mul(d, e);
        let db = Self::batch_mul_public(&triple_b, d);
        let ea = Self::batch_mul_public(&triple_a, e);

        let de_db = Self::batch_add_public(&db, &de);
        let ea_c = Self::batch_add(&ea, &triple_c);
        Self::batch_add(&de_db, &ea_c)
    }

    /// Multiply a batch of authenticated values by a batch of public values
    /// element-wise in one batched gate.
    ///
    /// The share, MAC share and public modifier are each scaled. Returns an
    /// empty vector when both batches are empty.
    ///
    /// # Panics
    /// If `a` and `b` have different lengths.
    pub fn batch_mul_public(
        a: &[AuthenticatedScalarResult<C>],
        b: &[ScalarResult<C>],
    ) -> Vec<AuthenticatedScalarResult<C>> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot multiply batches of different sizes"
        );
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();

        // Gate inputs: every flattened component of `a`, then the `n` public values
        let secret_ids = a.iter().flat_map(|a| a.ids());
        let public_ids = b.iter().map(|b| b.id());
        let all_ids = secret_ids.chain(public_ids).collect_vec();

        let scalars: Vec<ScalarResult<C>> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n,
            move |mut args| {
                let public_args = args.split_off(AUTHENTICATED_SCALAR_RESULT_LEN * n);

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (components, public_arg) in args
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .zip(public_args)
                {
                    let scale: Scalar<C> = public_arg.into();
                    for component in components {
                        let component: Scalar<C> = component.to_owned().into();
                        result.push(ResultValue::Scalar(component * scale));
                    }
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
    }
}
impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &CurvePoint<C> {
    type Output = AuthenticatedPointResult<C>;

    /// Multiply a public curve point by an authenticated scalar.
    ///
    /// The point is multiplied by each component, yielding the share, MAC share
    /// and public modifier of an authenticated point.
    fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
        let share = self * &rhs.share;
        let mac = self * &rhs.mac;
        let public_modifier = self * &rhs.public_modifier;

        AuthenticatedPointResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_commutative!(CurvePoint<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);
impl<C: CurveGroup> Mul<&AuthenticatedScalarResult<C>> for &CurvePointResult<C> {
    type Output = AuthenticatedPointResult<C>;

    /// Multiply a public (not yet resolved) curve point by an authenticated
    /// scalar, component by component.
    fn mul(self, rhs: &AuthenticatedScalarResult<C>) -> Self::Output {
        let share = self * &rhs.share;
        let mac = self * &rhs.mac;
        let public_modifier = self * &rhs.public_modifier;

        AuthenticatedPointResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(CurvePointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);
impl_commutative!(CurvePointResult<C>, Mul, mul, *, AuthenticatedScalarResult<C>, Output=AuthenticatedPointResult<C>, C: CurveGroup);
#[cfg(feature = "test_helpers")]
pub mod test_helpers {
    //! Test-only helpers that overwrite individual components of an
    //! `AuthenticatedScalarResult` with freshly allocated values.

    use ark_ec::CurveGroup;

    use crate::algebra::scalar::Scalar;

    use super::AuthenticatedScalarResult;

    /// Replace the MAC share of `val` with a freshly allocated `new_value`
    pub fn modify_mac<C: CurveGroup>(val: &mut AuthenticatedScalarResult<C>, new_value: Scalar<C>) {
        let replacement = val.fabric().allocate_scalar(new_value);
        val.mac = replacement.into();
    }

    /// Replace the value share of `val` with a freshly allocated `new_value`
    pub fn modify_share<C: CurveGroup>(
        val: &mut AuthenticatedScalarResult<C>,
        new_value: Scalar<C>,
    ) {
        let replacement = val.fabric().allocate_scalar(new_value);
        val.share = replacement.into();
    }

    /// Replace the public modifier of `val` with a freshly allocated `new_value`
    pub fn modify_public_modifier<C: CurveGroup>(
        val: &mut AuthenticatedScalarResult<C>,
        new_value: Scalar<C>,
    ) {
        let replacement = val.fabric().allocate_scalar(new_value);
        val.public_modifier = replacement;
    }
}
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Subtracting a public `ScalarResult` from a shared value, and the reverse,
    /// both open (with a MAC check) to the expected differences
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(value1, PARTY0);
            let public = fabric.allocate_scalar(value2);

            let shared_minus_public = (&shared - &public).open_authenticated().await.unwrap();
            let public_minus_shared = (&public - &shared).open_authenticated().await.unwrap();

            (
                shared_minus_public == value1 - value2,
                public_minus_shared == value2 - value1,
            )
        })
        .await;

        let (first_ok, second_ok) = res;
        assert!(first_ok);
        assert!(second_ok)
    }

    /// Subtracting a plain `Scalar` constant from a shared value, and the
    /// reverse, both open (with a MAC check) to the expected differences
    #[tokio::test]
    async fn test_sub_constant() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);

        let (res, _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(value1, PARTY0);

            let shared_minus_const = (&shared - value2).open_authenticated().await.unwrap();
            let const_minus_shared = (value2 - &shared).open_authenticated().await.unwrap();

            (
                shared_minus_const == value1 - value2,
                const_minus_shared == value2 - value1,
            )
        })
        .await;

        let (first_ok, second_ok) = res;
        assert!(first_ok);
        assert!(second_ok)
    }

    /// The arithmetic XOR circuit `a + b - 2ab` evaluated on two zeros opens
    /// to zero with a passing MAC check
    #[tokio::test]
    async fn test_xor_circuit() {
        let (res, _) = execute_mock_mpc(|fabric| async move {
            let a = &fabric.zero_authenticated();
            let b = &fabric.zero_authenticated();

            let sum = a + b;
            let twice_product = Scalar::from(2u64) * a * b;
            (sum - twice_product).open_authenticated().await
        })
        .await;

        assert_eq!(res.unwrap(), 0u8.into());
    }
}