use std::{
fmt::Debug,
iter::Sum,
ops::{Add, Mul, Neg, Sub},
pin::Pin,
task::{Context, Poll},
};
use futures::{Future, FutureExt};
use itertools::{izip, Itertools};
use crate::{
commitment::{PedersenCommitment, PedersenCommitmentResult},
error::MpcError,
fabric::{MpcFabric, ResultId, ResultValue},
ResultHandle, PARTY0,
};
use super::{
authenticated_stark_point::AuthenticatedStarkPointResult,
macros::{impl_borrow_variants, impl_commutative},
mpc_scalar::MpcScalarResult,
scalar::{BatchScalarResult, Scalar, ScalarResult},
stark_curve::{StarkPoint, StarkPointResult},
};
/// Number of `ScalarResult`s an `AuthenticatedScalarResult` flattens into:
/// the share, the MAC share, and the public modifier, in that order
pub const AUTHENTICATED_SCALAR_RESULT_LEN: usize = 3;

/// A secret-shared scalar carried together with this party's share of its MAC
/// and a public modifier
#[derive(Clone)]
pub struct AuthenticatedScalarResult {
    /// This party's share of the underlying value
    pub(crate) share: MpcScalarResult,
    /// This party's share of the MAC on the underlying value
    pub(crate) mac: MpcScalarResult,
    /// Public correction tracked by operations with public operands; it is
    /// added to the opened value before the MAC is checked
    pub(crate) public_modifier: ScalarResult,
}
impl Debug for AuthenticatedScalarResult {
    /// Formats the value using the fabric result IDs of its three components
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("AuthenticatedScalarResult");
        builder.field("value", &self.share.id());
        builder.field("mac", &self.mac.id());
        builder.field("public_modifier", &self.public_modifier.id);
        builder.finish()
    }
}
impl AuthenticatedScalarResult {
    /// Wrap a `ScalarResult` as this party's share of an authenticated value.
    ///
    /// The MAC share is computed by multiplying the fabric's MAC key with the
    /// shared value; the public modifier starts at zero.
    pub fn new_shared(value: ScalarResult) -> Self {
        let fabric = value.fabric.clone();
        let mpc_value = MpcScalarResult::new_shared(value);
        let mac = fabric.borrow_mac_key() * mpc_value.clone();
        let public_modifier = fabric.zero();

        Self {
            share: mpc_value,
            mac,
            public_modifier,
        }
    }

    /// Batch version of `new_shared`: all MACs are computed with a single
    /// `MpcScalarResult::batch_mul` against the fabric's MAC key.
    ///
    /// Returns an empty vector for empty input.
    pub fn new_shared_batch(values: &[ScalarResult]) -> Vec<Self> {
        if values.is_empty() {
            return vec![];
        }

        let n = values.len();
        let fabric = values[0].fabric();
        let mpc_values = values
            .iter()
            .map(|v| MpcScalarResult::new_shared(v.clone()))
            .collect_vec();

        // Every value is multiplied by the same MAC key
        let mac_keys = (0..n)
            .map(|_| fabric.borrow_mac_key().clone())
            .collect_vec();
        let values_macs = MpcScalarResult::batch_mul(&mpc_values, &mac_keys);

        mpc_values
            .into_iter()
            .zip(values_macs.into_iter())
            .map(|(value, mac)| Self {
                share: value,
                mac,
                public_modifier: fabric.zero(),
            })
            .collect_vec()
    }

    /// Split a `BatchScalarResult` holding `n` scalars into individual results
    /// and authenticate each of them as in `new_shared_batch`
    pub fn new_shared_from_batch_result(
        values: BatchScalarResult,
        n: usize,
    ) -> Vec<AuthenticatedScalarResult> {
        let scalar_results = values
            .fabric()
            .new_batch_gate_op(vec![values.id()], n, |mut args| {
                let scalars: Vec<Scalar> = args.pop().unwrap().into();
                scalars.into_iter().map(ResultValue::Scalar).collect()
            });

        Self::new_shared_batch(&scalar_results)
    }

    /// Get a clone of the underlying MPC share (test helper)
    #[cfg(feature = "test_helpers")]
    pub fn mpc_share(&self) -> MpcScalarResult {
        self.share.clone()
    }

    /// Get this party's share as a plain `ScalarResult`
    pub fn share(&self) -> ScalarResult {
        self.share.to_scalar()
    }

    /// Get the fabric the share was allocated in
    pub fn fabric(&self) -> &MpcFabric {
        self.share.fabric()
    }

    /// The result IDs of the share, MAC share, and public modifier, in the
    /// order expected by `from_flattened_iterator`
    pub fn ids(&self) -> Vec<ResultId> {
        vec![self.share.id(), self.mac.id(), self.public_modifier.id]
    }

    /// Open the underlying share *without* checking the MAC
    pub fn open(&self) -> ScalarResult {
        self.share.open()
    }

    /// Open a batch of shares *without* checking their MACs
    pub fn open_batch(values: &[Self]) -> Vec<ScalarResult> {
        MpcScalarResult::open_batch(&values.iter().map(|val| val.share.clone()).collect_vec())
    }

    /// Rebuild authenticated values from a flattened stream of
    /// `(share, mac, public_modifier)` triples, as produced by the batch gates
    /// in this module.
    ///
    /// # Panics
    /// Panics if the stream length is not a multiple of
    /// `AUTHENTICATED_SCALAR_RESULT_LEN`.
    pub fn from_flattened_iterator<I>(iter: I) -> Vec<Self>
    where
        I: Iterator<Item = ResultHandle<Scalar>>,
    {
        const TRUNCATED: &str = "flattened iterator length is not a multiple of AUTHENTICATED_SCALAR_RESULT_LEN";
        iter.chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
            .into_iter()
            .map(|mut chunk| Self {
                share: chunk.next().expect(TRUNCATED).into(),
                mac: chunk.next().expect(TRUNCATED).into(),
                public_modifier: chunk.next().expect(TRUNCATED),
            })
            .collect_vec()
    }

    /// Check the peer's opened MAC-check value against its commitment.
    ///
    /// Returns `true` only when the Pedersen commitment opens to
    /// `peer_mac_share` under `peer_commitment_blinder`, and the two parties'
    /// MAC-check values sum to zero.
    pub fn verify_mac_check(
        my_mac_share: Scalar,
        peer_mac_share: Scalar,
        peer_mac_commitment: StarkPoint,
        peer_commitment_blinder: Scalar,
    ) -> bool {
        let their_comm = PedersenCommitment {
            value: peer_mac_share,
            blinder: peer_commitment_blinder,
            commitment: peer_mac_commitment,
        };

        their_comm.verify() && peer_mac_share + my_mac_share == Scalar::from(0)
    }

    /// Open the value and run the MAC check.
    ///
    /// Each party computes `mac_key_share * (opened + public_modifier) - mac_share`,
    /// commits to it, and exchanges the commitment, the value, and the blinder
    /// with its peer. The returned result resolves to an error if the peer's
    /// commitment does not verify or the two values do not sum to zero.
    pub fn open_authenticated(&self) -> AuthenticatedScalarOpenResult {
        let recovered_value = self.share.open();

        let mac_check_value: ScalarResult = self.fabric().new_gate_op(
            vec![
                self.fabric().borrow_mac_key().id(),
                recovered_value.id,
                self.public_modifier.id,
                self.mac.id(),
            ],
            move |mut args| {
                let mac_key_share: Scalar = args.remove(0).into();
                let value: Scalar = args.remove(0).into();
                let modifier: Scalar = args.remove(0).into();
                let mac_share: Scalar = args.remove(0).into();

                ResultValue::Scalar(mac_key_share * (value + modifier) - mac_share)
            },
        );

        // Commit to the MAC check value and exchange commitment, value, and blinder
        let my_comm = PedersenCommitmentResult::commit(mac_check_value);
        let peer_commit = self.fabric().exchange_value(my_comm.commitment);
        let peer_mac_check = self.fabric().exchange_value(my_comm.value.clone());

        let blinder_result: ScalarResult = self.fabric().allocate_scalar(my_comm.blinder);
        let peer_blinder = self.fabric().exchange_value(blinder_result);

        // Verify the peer's commitment and that the MAC check values cancel
        let commitment_check: ScalarResult = self.fabric().new_gate_op(
            vec![
                my_comm.value.id,
                peer_mac_check.id,
                peer_blinder.id,
                peer_commit.id,
            ],
            |mut args| {
                let my_comm_value: Scalar = args.remove(0).into();
                let peer_value: Scalar = args.remove(0).into();
                let blinder: Scalar = args.remove(0).into();
                let commitment: StarkPoint = args.remove(0).into();

                ResultValue::Scalar(Scalar::from(Self::verify_mac_check(
                    my_comm_value,
                    peer_value,
                    commitment,
                    blinder,
                )))
            },
        );

        AuthenticatedScalarOpenResult {
            value: recovered_value,
            mac_check: commitment_check,
        }
    }

    /// Batch version of `open_authenticated`: opens all values and runs every
    /// MAC check using batch gates and batch exchanges.
    ///
    /// Returns an empty vector for empty input.
    pub fn open_authenticated_batch(values: &[Self]) -> Vec<AuthenticatedScalarOpenResult> {
        if values.is_empty() {
            return vec![];
        }

        let n = values.len();
        let fabric = values[0].fabric();

        let values_open = Self::open_batch(values);

        // Gate inputs: the MAC key share, then one `(opened value, modifier, mac share)`
        // triple per value
        let mut mac_check_deps = Vec::with_capacity(1 + AUTHENTICATED_SCALAR_RESULT_LEN * n);
        mac_check_deps.push(fabric.borrow_mac_key().id());
        for (opened, value) in values_open.iter().zip(values.iter()) {
            mac_check_deps.push(opened.id());
            mac_check_deps.push(value.public_modifier.id());
            mac_check_deps.push(value.mac.id());
        }

        let mac_checks: Vec<ScalarResult> =
            fabric.new_batch_gate_op(mac_check_deps, n /* output_arity */, move |args| {
                // Consume the arguments front-to-back with an iterator; repeatedly calling
                // `Vec::remove(0)` shifts the whole vector each time and is quadratic in `n`
                let mut args = args.into_iter();
                let mac_key_share: Scalar = args.next().unwrap().into();

                (0..n)
                    .map(|_| {
                        let value: Scalar = args.next().unwrap().into();
                        let modifier: Scalar = args.next().unwrap().into();
                        let mac_share: Scalar = args.next().unwrap().into();
                        ResultValue::Scalar(mac_key_share * (value + modifier) - mac_share)
                    })
                    .collect()
            });

        // Commit to each MAC check value and exchange commitments, values, and blinders
        let my_comms = mac_checks
            .iter()
            .cloned()
            .map(PedersenCommitmentResult::commit)
            .collect_vec();
        let peer_comms = fabric.exchange_values(
            &my_comms
                .iter()
                .map(|comm| comm.commitment.clone())
                .collect_vec(),
        );
        let peer_mac_checks = fabric.exchange_values(&mac_checks);
        let peer_blinders = fabric.exchange_values(
            &my_comms
                .iter()
                .map(|comm| fabric.allocate_scalar(comm.blinder))
                .collect_vec(),
        );

        // Gate inputs: our `n` MAC check values, then the three batched peer results
        let mut mac_check_gate_deps = my_comms.iter().map(|comm| comm.value.id).collect_vec();
        mac_check_gate_deps.push(peer_mac_checks.id);
        mac_check_gate_deps.push(peer_blinders.id);
        mac_check_gate_deps.push(peer_comms.id);

        let commitment_checks: Vec<ScalarResult> = fabric.new_batch_gate_op(
            mac_check_gate_deps,
            n, /* output_arity */
            move |mut args| {
                let my_comms: Vec<Scalar> = args.drain(..n).map(|comm| comm.into()).collect();
                let peer_mac_checks: Vec<Scalar> = args.remove(0).into();
                let peer_blinders: Vec<Scalar> = args.remove(0).into();
                let peer_comms: Vec<StarkPoint> = args.remove(0).into();

                let mut mac_checks = Vec::with_capacity(n);
                for (my_mac_share, peer_mac_share, peer_blinder, peer_commitment) in izip!(
                    my_comms.into_iter(),
                    peer_mac_checks.into_iter(),
                    peer_blinders.into_iter(),
                    peer_comms.into_iter()
                ) {
                    let mac_check = Self::verify_mac_check(
                        my_mac_share,
                        peer_mac_share,
                        peer_commitment,
                        peer_blinder,
                    );
                    mac_checks.push(ResultValue::Scalar(Scalar::from(mac_check)));
                }

                mac_checks
            },
        );

        values_open
            .into_iter()
            .zip(commitment_checks.into_iter())
            .map(|(value, check)| AuthenticatedScalarOpenResult {
                value,
                mac_check: check,
            })
            .collect_vec()
    }
}
/// The outcome of an authenticated opening: the opened value together with
/// the result of its MAC check
#[derive(Clone)]
pub struct AuthenticatedScalarOpenResult {
    /// The opened value
    pub value: ScalarResult,
    /// The MAC check outcome; the `Future` impl accepts only a value of one
    pub mac_check: ScalarResult,
}
impl Future for AuthenticatedScalarOpenResult {
    type Output = Result<Scalar, MpcError>;

    /// Waits for the opened value and then for the MAC check. Resolves to the
    /// value when the check equals one, and to `MpcError::AuthenticationError`
    /// otherwise.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let opened = match self.as_mut().value.poll_unpin(cx) {
            Poll::Ready(opened) => opened,
            Poll::Pending => return Poll::Pending,
        };
        let check = match self.as_mut().mac_check.poll_unpin(cx) {
            Poll::Ready(check) => check,
            Poll::Pending => return Poll::Pending,
        };

        let outcome = if check != Scalar::from(1) {
            Err(MpcError::AuthenticationError)
        } else {
            Ok(opened)
        };
        Poll::Ready(outcome)
    }
}
impl Add<&Scalar> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Add a public scalar to an authenticated value.
    ///
    /// Only party 0 adds `rhs` to its share; party 1 adds zero so that it
    /// also produces a new share result. Both parties subtract `rhs` from the
    /// public modifier, and the MAC share is carried over unchanged.
    fn add(self, rhs: &Scalar) -> Self::Output {
        let share = if self.fabric().party_id() != PARTY0 {
            &self.share + Scalar::from(0)
        } else {
            &self.share + rhs
        };
        let public_modifier = &self.public_modifier - rhs;

        AuthenticatedScalarResult {
            share,
            mac: self.mac.clone(),
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, Scalar, Output=AuthenticatedScalarResult);
impl_commutative!(AuthenticatedScalarResult, Add, add, +, Scalar, Output=AuthenticatedScalarResult);
impl Add<&ScalarResult> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Add a public (computed) scalar to an authenticated value.
    ///
    /// Only party 0 adds `rhs` to its share; party 1 adds zero so that it
    /// also produces a new share result. Both parties subtract `rhs` from the
    /// public modifier, and the MAC share is carried over unchanged.
    fn add(self, rhs: &ScalarResult) -> Self::Output {
        let share = if self.fabric().party_id() != PARTY0 {
            &self.share + Scalar::from(0)
        } else {
            &self.share + rhs
        };
        let public_modifier = &self.public_modifier - rhs;

        AuthenticatedScalarResult {
            share,
            mac: self.mac.clone(),
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, ScalarResult, Output=AuthenticatedScalarResult);
impl_commutative!(AuthenticatedScalarResult, Add, add, +, ScalarResult, Output=AuthenticatedScalarResult);
impl Add<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Add two authenticated values component-wise: shares, MAC shares, and
    /// public modifiers are each summed
    fn add(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
        let share = &self.share + &rhs.share;
        let mac = &self.mac + &rhs.mac;
        let public_modifier = self.public_modifier.clone() + rhs.public_modifier.clone();

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Add, add, +, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
impl AuthenticatedScalarResult {
    /// Add two batches of authenticated values element-wise in a single gate.
    ///
    /// Each output's share, MAC share, and public modifier is the sum of the
    /// corresponding components of `a[i]` and `b[i]`. Empty batches yield an
    /// empty result.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths.
    pub fn batch_add(
        a: &[AuthenticatedScalarResult],
        b: &[AuthenticatedScalarResult],
    ) -> Vec<AuthenticatedScalarResult> {
        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
        // Guard before `a[0]` below, as `batch_mul` and `batch_neg` do
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();

        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
            move |mut args| {
                // The first half of the arguments are the components of `a`, the rest of `b`
                let arg_len = args.len();
                let a_vals = args.drain(..arg_len / 2).collect_vec();
                let b_vals = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, mut b_vals) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(
                        b_vals
                            .into_iter()
                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                            .into_iter(),
                    )
                {
                    let a_share: Scalar = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
                    let a_modifier: Scalar = a_vals.next().unwrap().into();

                    let b_share: Scalar = b_vals.next().unwrap().into();
                    let b_mac_share: Scalar = b_vals.next().unwrap().into();
                    let b_modifier: Scalar = b_vals.next().unwrap().into();

                    result.push(ResultValue::Scalar(a_share + b_share));
                    result.push(ResultValue::Scalar(a_mac_share + b_mac_share));
                    result.push(ResultValue::Scalar(a_modifier + b_modifier));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }

    /// Add a batch of public values to a batch of authenticated values
    /// element-wise in a single gate.
    ///
    /// Only party 0 adds `b[i]` to its share. Both parties subtract `b[i]` from
    /// the public modifier, and the MAC shares pass through unchanged. Empty
    /// batches yield an empty result.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths.
    pub fn batch_add_public(
        a: &[AuthenticatedScalarResult],
        b: &[ScalarResult],
    ) -> Vec<AuthenticatedScalarResult> {
        assert_eq!(a.len(), b.len(), "Cannot add batches of different sizes");
        // Guard before `a[0]` below, as `batch_mul` and `batch_neg` do
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a
            .iter()
            .flat_map(|v| v.ids())
            .chain(b.iter().map(|v| v.id()))
            .collect_vec();

        let party_id = fabric.party_id();
        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
            move |mut args| {
                // Authenticated components first, then one public value per element
                let a_vals = args
                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
                    .collect_vec();
                let public_values = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, public_value) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(public_values.into_iter())
                {
                    let a_share: Scalar = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
                    let a_modifier: Scalar = a_vals.next().unwrap().into();
                    let public_value: Scalar = public_value.into();

                    // Only party 0 adds the public value into its share
                    if party_id == PARTY0 {
                        result.push(ResultValue::Scalar(a_share + public_value));
                    } else {
                        result.push(ResultValue::Scalar(a_share));
                    }
                    result.push(ResultValue::Scalar(a_mac_share));
                    result.push(ResultValue::Scalar(a_modifier - public_value));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }
}
impl Sum for AuthenticatedScalarResult {
    /// Sum the values in iteration order.
    ///
    /// # Panics
    /// Panics if the iterator is empty.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        let mut rest = iter;
        let mut total = rest.next().expect("Cannot sum empty iterator");
        for val in rest {
            total = total + &val;
        }
        total
    }
}
impl Sub<&Scalar> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Subtract a public scalar from an authenticated value. `rhs` is added to
    /// the public modifier and the MAC share is unchanged.
    ///
    /// NOTE(review): unlike `Add`, this does not branch on the party id. It
    /// presumably relies on `MpcScalarResult` subtraction applying the public
    /// value on one party only; confirm.
    fn sub(self, rhs: &Scalar) -> Self::Output {
        AuthenticatedScalarResult {
            share: &self.share - rhs,
            mac: self.mac.clone(),
            public_modifier: &self.public_modifier + rhs,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, Scalar, Output=AuthenticatedScalarResult);
impl Sub<&AuthenticatedScalarResult> for &Scalar {
    type Output = AuthenticatedScalarResult;

    /// Compute `self - rhs` for a public scalar `self`. The share is
    /// `self - share`, the MAC share is negated, and the public modifier
    /// becomes `-self - modifier`.
    fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
        let share = self - &rhs.share;
        let public_modifier = -self - &rhs.public_modifier;
        let mac = -&rhs.mac;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(Scalar, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
impl Sub<&ScalarResult> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Subtract a public (computed) scalar from an authenticated value. `rhs`
    /// is added to the public modifier and the MAC share is unchanged.
    fn sub(self, rhs: &ScalarResult) -> Self::Output {
        AuthenticatedScalarResult {
            share: &self.share - rhs,
            mac: self.mac.clone(),
            public_modifier: &self.public_modifier + rhs,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, ScalarResult, Output=AuthenticatedScalarResult);
impl Sub<&AuthenticatedScalarResult> for &ScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Compute `self - rhs` for a public (computed) scalar `self`. The share is
    /// `self - share`, the MAC share is negated, and the public modifier
    /// becomes `-self - modifier`.
    fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
        let share = self - &rhs.share;
        let public_modifier = -self - &rhs.public_modifier;
        let mac = -&rhs.mac;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(ScalarResult, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
impl Sub<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Subtract two authenticated values component-wise: shares, MAC shares,
    /// and public modifiers are each subtracted
    fn sub(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
        let share = &self.share - &rhs.share;
        let mac = &self.mac - &rhs.mac;
        let public_modifier = self.public_modifier.clone() - rhs.public_modifier.clone();

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Sub, sub, -, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
impl AuthenticatedScalarResult {
    /// Subtract two batches of authenticated values element-wise in a single gate.
    ///
    /// Each output's share, MAC share, and public modifier is `a[i]`'s
    /// component minus `b[i]`'s. Empty batches yield an empty result.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths.
    pub fn batch_sub(
        a: &[AuthenticatedScalarResult],
        b: &[AuthenticatedScalarResult],
    ) -> Vec<AuthenticatedScalarResult> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot subtract batches of different sizes"
        );
        // Guard before `a[0]` below, as `batch_mul` and `batch_neg` do
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a.iter().chain(b.iter()).flat_map(|v| v.ids()).collect_vec();

        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
            move |mut args| {
                // The first half of the arguments are the components of `a`, the rest of `b`
                let arg_len = args.len();
                let a_vals = args.drain(..arg_len / 2).collect_vec();
                let b_vals = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, mut b_vals) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(
                        b_vals
                            .into_iter()
                            .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                            .into_iter(),
                    )
                {
                    let a_share: Scalar = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
                    let a_modifier: Scalar = a_vals.next().unwrap().into();

                    let b_share: Scalar = b_vals.next().unwrap().into();
                    let b_mac_share: Scalar = b_vals.next().unwrap().into();
                    let b_modifier: Scalar = b_vals.next().unwrap().into();

                    result.push(ResultValue::Scalar(a_share - b_share));
                    result.push(ResultValue::Scalar(a_mac_share - b_mac_share));
                    result.push(ResultValue::Scalar(a_modifier - b_modifier));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }

    /// Subtract a batch of public values from a batch of authenticated values
    /// element-wise in a single gate.
    ///
    /// Only party 0 subtracts `b[i]` from its share. Both parties add `b[i]` to
    /// the public modifier, and the MAC shares pass through unchanged. Empty
    /// batches yield an empty result.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths.
    pub fn batch_sub_public(
        a: &[AuthenticatedScalarResult],
        b: &[ScalarResult],
    ) -> Vec<AuthenticatedScalarResult> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot subtract batches of different sizes"
        );
        // Guard before `a[0]` below, as `batch_mul` and `batch_neg` do
        if a.is_empty() {
            return vec![];
        }

        let n = a.len();
        let fabric = a[0].fabric();
        let all_ids = a
            .iter()
            .flat_map(|v| v.ids())
            .chain(b.iter().map(|v| v.id()))
            .collect_vec();

        let party_id = fabric.party_id();
        let gate_results: Vec<ScalarResult> = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
            move |mut args| {
                // Authenticated components first, then one public value per element
                let a_vals = args
                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
                    .collect_vec();
                let public_values = args;

                let mut result = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (mut a_vals, public_value) in a_vals
                    .into_iter()
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .into_iter()
                    .zip(public_values.into_iter())
                {
                    let a_share: Scalar = a_vals.next().unwrap().into();
                    let a_mac_share: Scalar = a_vals.next().unwrap().into();
                    let a_modifier: Scalar = a_vals.next().unwrap().into();
                    let public_value: Scalar = public_value.into();

                    // Only party 0 subtracts the public value from its share
                    if party_id == PARTY0 {
                        result.push(ResultValue::Scalar(a_share - public_value));
                    } else {
                        result.push(ResultValue::Scalar(a_share));
                    }
                    result.push(ResultValue::Scalar(a_mac_share));
                    result.push(ResultValue::Scalar(a_modifier + public_value));
                }

                result
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_results.into_iter())
    }
}
impl Neg for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Negate every component: the share, the MAC share, and the public modifier
    fn neg(self) -> Self::Output {
        let share = -&self.share;
        let mac = -&self.mac;
        let public_modifier = -&self.public_modifier;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Neg, neg, -);
impl AuthenticatedScalarResult {
pub fn batch_neg(a: &[AuthenticatedScalarResult]) -> Vec<AuthenticatedScalarResult> {
if a.is_empty() {
return vec![];
}
let n = a.len();
let fabric = a[0].fabric();
let all_ids = a.iter().flat_map(|v| v.ids()).collect_vec();
let scalars = fabric.new_batch_gate_op(
all_ids,
AUTHENTICATED_SCALAR_RESULT_LEN * n,
|args| {
args.into_iter()
.map(|arg| ResultValue::Scalar(-Scalar::from(arg)))
.collect()
},
);
AuthenticatedScalarResult::from_flattened_iterator(scalars.into_iter())
}
}
impl Mul<&Scalar> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Scale an authenticated value by a public scalar. The share, MAC share,
    /// and public modifier are each multiplied by `rhs`.
    fn mul(self, rhs: &Scalar) -> Self::Output {
        let share = &self.share * rhs;
        let mac = &self.mac * rhs;
        let public_modifier = &self.public_modifier * rhs;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, Scalar, Output=AuthenticatedScalarResult);
impl_commutative!(AuthenticatedScalarResult, Mul, mul, *, Scalar, Output=AuthenticatedScalarResult);
impl Mul<&ScalarResult> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Scale an authenticated value by a public (computed) scalar. The share,
    /// MAC share, and public modifier are each multiplied by `rhs`.
    fn mul(self, rhs: &ScalarResult) -> Self::Output {
        let share = &self.share * rhs;
        let mac = &self.mac * rhs;
        let public_modifier = &self.public_modifier * rhs;

        AuthenticatedScalarResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, ScalarResult, Output=AuthenticatedScalarResult);
impl_commutative!(AuthenticatedScalarResult, Mul, mul, *, ScalarResult, Output=AuthenticatedScalarResult);
impl Mul<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
    type Output = AuthenticatedScalarResult;

    /// Multiply two authenticated values using one authenticated triple
    /// `(x, y, z)` from the fabric.
    ///
    /// The masked operands `d = self - x` and `e = rhs - y` are opened, and the
    /// product is reconstructed as `d * e + d * y + e * x + z`.
    fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
        let (triple_x, triple_y, triple_z) = self.fabric().next_authenticated_triple();

        let lhs_masked = self - &triple_x;
        let rhs_masked = rhs - &triple_y;
        let d = lhs_masked.open();
        let e = rhs_masked.open();

        let de = &d * &e;
        de + d * triple_y + e * triple_x + triple_z
    }
}
impl_borrow_variants!(AuthenticatedScalarResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedScalarResult);
impl AuthenticatedScalarResult {
    /// Multiply two batches of authenticated values element-wise using a batch
    /// of authenticated triples `(x_i, y_i, z_i)`.
    ///
    /// All masked operands `d_i = a[i] - x_i` and `e_i = b[i] - y_i` are opened
    /// in one batch, and each product is reconstructed as
    /// `d_i * e_i + d_i * y_i + e_i * x_i + z_i`. Empty input yields an empty
    /// result.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths.
    pub fn batch_mul(
        a: &[AuthenticatedScalarResult],
        b: &[AuthenticatedScalarResult],
    ) -> Vec<AuthenticatedScalarResult> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot multiply batches of different sizes"
        );
        let n = a.len();
        if n == 0 {
            return vec![];
        }

        let fabric = a[0].fabric();
        let (triple_x, triple_y, triple_z) = fabric.next_authenticated_triple_batch(n);

        // Mask both operand batches, then open all masks together
        let mut masks = AuthenticatedScalarResult::batch_sub(a, &triple_x);
        masks.extend(AuthenticatedScalarResult::batch_sub(b, &triple_y));
        let opened = AuthenticatedScalarResult::open_batch(&masks);
        let (d, e) = opened.split_at(n);

        // Reconstruct: (d * e + d * y) + (e * x + z)
        let d_times_e = ScalarResult::batch_mul(d, e);
        let d_times_y = AuthenticatedScalarResult::batch_mul_public(&triple_y, d);
        let e_times_x = AuthenticatedScalarResult::batch_mul_public(&triple_x, e);

        let left = AuthenticatedScalarResult::batch_add_public(&d_times_y, &d_times_e);
        let right = AuthenticatedScalarResult::batch_add(&e_times_x, &triple_z);
        AuthenticatedScalarResult::batch_add(&left, &right)
    }

    /// Multiply each authenticated value `a[i]` by the public value `b[i]` in a
    /// single gate. The share, MAC share, and public modifier are all scaled.
    /// Empty input yields an empty result.
    ///
    /// # Panics
    /// Panics if `a` and `b` have different lengths.
    pub fn batch_mul_public(
        a: &[AuthenticatedScalarResult],
        b: &[ScalarResult],
    ) -> Vec<AuthenticatedScalarResult> {
        assert_eq!(
            a.len(),
            b.len(),
            "Cannot multiply batches of different sizes"
        );
        let n = a.len();
        if n == 0 {
            return vec![];
        }

        let fabric = a[0].fabric();
        let all_ids = a
            .iter()
            .flat_map(|a| a.ids())
            .chain(b.iter().map(|b| b.id()))
            .collect_vec();

        let gate_outputs = fabric.new_batch_gate_op(
            all_ids,
            AUTHENTICATED_SCALAR_RESULT_LEN * n, /* output_arity */
            move |mut args| {
                // Authenticated components first, then one public value per element
                let authenticated = args
                    .drain(..AUTHENTICATED_SCALAR_RESULT_LEN * n)
                    .collect_vec();
                let publics = args;

                let mut out = Vec::with_capacity(AUTHENTICATED_SCALAR_RESULT_LEN * n);
                for (components, public) in authenticated
                    .chunks(AUTHENTICATED_SCALAR_RESULT_LEN)
                    .zip(publics.into_iter())
                {
                    let scale: Scalar = public.into();
                    // Scale share, MAC share, and modifier in their flattened order
                    for component in components {
                        let component: Scalar = component.to_owned().into();
                        out.push(ResultValue::Scalar(component * scale));
                    }
                }
                out
            },
        );

        AuthenticatedScalarResult::from_flattened_iterator(gate_outputs.into_iter())
    }
}
impl Mul<&AuthenticatedScalarResult> for &StarkPoint {
    type Output = AuthenticatedStarkPointResult;

    /// Multiply a public point by an authenticated scalar. The point is scaled
    /// by each component (share, MAC share, public modifier) separately.
    fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
        let share = self * &rhs.share;
        let mac = self * &rhs.mac;
        let public_modifier = self * &rhs.public_modifier;

        AuthenticatedStarkPointResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_commutative!(StarkPoint, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
impl Mul<&AuthenticatedScalarResult> for &StarkPointResult {
    type Output = AuthenticatedStarkPointResult;

    /// Multiply a public (computed) point by an authenticated scalar. The
    /// point is scaled by each component (share, MAC share, public modifier)
    /// separately.
    fn mul(self, rhs: &AuthenticatedScalarResult) -> Self::Output {
        let share = self * &rhs.share;
        let mac = self * &rhs.mac;
        let public_modifier = self * &rhs.public_modifier;

        AuthenticatedStarkPointResult {
            share,
            mac,
            public_modifier,
        }
    }
}
impl_borrow_variants!(StarkPointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
impl_commutative!(StarkPointResult, Mul, mul, *, AuthenticatedScalarResult, Output=AuthenticatedStarkPointResult);
#[cfg(feature = "test_helpers")]
pub mod test_helpers {
    //! Helpers that overwrite components of an `AuthenticatedScalarResult`,
    //! e.g. to corrupt a value in tests

    use crate::algebra::scalar::Scalar;

    use super::AuthenticatedScalarResult;

    /// Replace the MAC share with a newly allocated `new_value`
    pub fn modify_mac(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        val.mac = val.fabric().allocate_scalar(new_value).into();
    }

    /// Replace the value share with a newly allocated `new_value`
    pub fn modify_share(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        val.share = val.fabric().allocate_scalar(new_value).into();
    }

    /// Replace the public modifier with a newly allocated `new_value`
    pub fn modify_public_modifier(val: &mut AuthenticatedScalarResult, new_value: Scalar) {
        val.public_modifier = val.fabric().allocate_scalar(new_value);
    }
}
#[cfg(test)]
mod tests {
    use rand::thread_rng;

    use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc, PARTY0};

    /// Subtraction between a shared value and a public `ScalarResult`, in both
    /// operand orders
    #[tokio::test]
    async fn test_sub() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);
        let expected_diff = value1 - value2;
        let expected_rev = value2 - value1;

        let ((diff_ok, rev_ok), _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(value1, PARTY0);
            let public = fabric.allocate_scalar(value2);

            let diff = (&shared - &public).open_authenticated().await.unwrap();
            let rev = (&public - &shared).open_authenticated().await.unwrap();

            (diff == expected_diff, rev == expected_rev)
        })
        .await;

        assert!(diff_ok);
        assert!(rev_ok)
    }

    /// Subtraction between a shared value and a plain `Scalar` constant, in
    /// both operand orders
    #[tokio::test]
    async fn test_sub_constant() {
        let mut rng = thread_rng();
        let value1 = Scalar::random(&mut rng);
        let value2 = Scalar::random(&mut rng);
        let expected_diff = value1 - value2;
        let expected_rev = value2 - value1;

        let ((diff_ok, rev_ok), _) = execute_mock_mpc(|fabric| async move {
            let shared = fabric.share_scalar(value1, PARTY0);

            let diff = (&shared - value2).open_authenticated().await.unwrap();
            let rev = (value2 - &shared).open_authenticated().await.unwrap();

            (diff == expected_diff, rev == expected_rev)
        })
        .await;

        assert!(diff_ok);
        assert!(rev_ok)
    }

    /// XOR of two authenticated zero bits via `a + b - 2ab` opens to zero
    #[tokio::test]
    async fn test_xor_circuit() {
        let (party0_res, _) = execute_mock_mpc(|fabric| async move {
            let a = &fabric.zero_authenticated();
            let b = &fabric.zero_authenticated();

            let xor = a + b - Scalar::from(2u64) * a * b;
            xor.open_authenticated().await
        })
        .await;

        assert_eq!(party0_res.unwrap(), 0.into());
    }
}