use std::fmt::Debug;
use futures::{future::join_all, Future};
use itertools::Itertools;
use mpc_stark::{
algebra::{
authenticated_scalar::AuthenticatedScalarResult,
authenticated_stark_point::AuthenticatedStarkPointResult, mpc_scalar::MpcScalarResult,
mpc_stark_point::MpcStarkPointResult, scalar::Scalar, stark_curve::StarkPoint,
},
beaver::SharedValueSource,
network::{NetworkPayload, PartyId},
{MpcFabric, ResultHandle, ResultValue},
};
use tokio::runtime::Handle;
use crate::IntegrationTestArgs;
/// Compares two scalars for equality.
///
/// Returns `Ok(())` when they match; otherwise returns an error message that
/// contains the `Debug` form of both values.
pub(crate) fn assert_scalars_eq(a: Scalar, b: Scalar) -> Result<(), String> {
    if a != b {
        return Err(format!("{a:?} != {b:?}"));
    }
    Ok(())
}
/// Compares two batches of scalars element-wise.
///
/// Fails if the batches differ in length; the error then shows both batches in
/// full. Otherwise fails with the error for the first mismatching pair.
pub(crate) fn assert_scalar_batches_eq(a: Vec<Scalar>, b: Vec<Scalar>) -> Result<(), String> {
    if a.len() == b.len() {
        a.into_iter()
            .zip(b)
            .try_for_each(|(lhs, rhs)| assert_scalars_eq(lhs, rhs))
    } else {
        Err(format!("Lengths differ: {a:?} != {b:?}"))
    }
}
/// Compares two curve points for equality.
///
/// Returns `Ok(())` when they match; otherwise returns an error message that
/// contains the `Debug` form of both points.
pub(crate) fn assert_points_eq(a: StarkPoint, b: StarkPoint) -> Result<(), String> {
    if a != b {
        return Err(format!("{a:?} != {b:?}"));
    }
    Ok(())
}
/// Compares two batches of curve points element-wise.
///
/// Fails if the batches differ in length; the error then shows both batches in
/// full. Otherwise fails with the error for the first mismatching pair.
pub(crate) fn assert_point_batches_eq(
    a: Vec<StarkPoint>,
    b: Vec<StarkPoint>,
) -> Result<(), String> {
    if a.len() == b.len() {
        a.into_iter()
            .zip(b)
            .try_for_each(|(lhs, rhs)| assert_points_eq(lhs, rhs))
    } else {
        Err(format!("Lengths differ: {a:?} != {b:?}"))
    }
}
/// Succeeds only when `res` is an `Err`.
///
/// The contents of either variant are ignored. An `Ok` value produces a fixed
/// error message.
pub(crate) fn assert_err<T, E>(res: Result<T, E>) -> Result<(), String> {
    match res {
        Err(_) => Ok(()),
        Ok(_) => Err("Expected error, got Ok".to_string()),
    }
}
/// Blocks the calling thread until `res` resolves, and returns its output.
///
/// The future is driven on the Tokio runtime obtained from
/// `Handle::current()`.
///
/// # Panics
/// `Handle::current()` panics when called outside a Tokio runtime context.
pub(crate) fn await_result<R, T: Future<Output = R>>(res: T) -> R {
    let runtime = Handle::current();
    runtime.block_on(res)
}
/// Blocks on each future in `res`, in order, and collects the outputs.
///
/// Each future is cloned before it is awaited, so the input slice is left
/// untouched. The futures are resolved one after another, not concurrently.
pub(crate) fn await_result_batch<R, T: Future<Output = R> + Clone>(res: &[T]) -> Vec<R> {
    let mut outputs = Vec::with_capacity(res.len());
    for fut in res {
        outputs.push(await_result(fut.clone()));
    }
    outputs
}
/// Blocks on a fallible future and converts any error into a `String`.
///
/// The error message embeds the `Debug` form of the original error.
pub(crate) fn await_result_with_error<R, E: Debug, T: Future<Output = Result<R, E>>>(
    res: T,
) -> Result<R, String> {
    let handle = Handle::current();
    match handle.block_on(res) {
        Ok(value) => Ok(value),
        Err(err) => Err(format!("Error awaiting result: {:?}", err)),
    }
}
/// Awaits a batch of fallible futures concurrently, via `join_all`.
///
/// Returns every output in input order. If any future failed, returns the first
/// error encountered in input order instead, formatted as a `String`.
pub(crate) fn await_batch_result_with_error<R, E, T>(res: Vec<T>) -> Result<Vec<R>, String>
where
    E: Debug,
    T: Future<Output = Result<R, E>>,
{
    let outcomes = Handle::current().block_on(join_all(res));

    let mut values = Vec::with_capacity(outcomes.len());
    for outcome in outcomes {
        match outcome {
            Ok(value) => values.push(value),
            Err(err) => return Err(format!("Error awaiting result: {:?}", err)),
        }
    }
    Ok(values)
}
/// Shares `value` from party `sender` and returns only the unauthenticated
/// (MPC) share.
///
/// The authenticated share is produced first and then discarded.
pub(crate) fn share_scalar(
    value: Scalar,
    sender: PartyId,
    test_args: &IntegrationTestArgs,
) -> MpcScalarResult {
    test_args.fabric.share_scalar(value, sender).mpc_share()
}
/// Batch-shares `values` from party `sender` and returns the unauthenticated
/// (MPC) share of each element, in order.
pub(crate) fn share_scalar_batch(
    values: Vec<Scalar>,
    sender: PartyId,
    test_args: &IntegrationTestArgs,
) -> Vec<MpcScalarResult> {
    let authenticated = test_args.fabric.batch_share_scalar(values, sender);
    authenticated.iter().map(|share| share.mpc_share()).collect()
}
pub(crate) fn share_point(
value: StarkPoint,
sender: PartyId,
test_args: &IntegrationTestArgs,
) -> MpcStarkPointResult {
let authenticated_point = share_authenticated_point(value, sender, test_args);
authenticated_point.mpc_share()
}
pub(crate) fn share_point_batch(
values: Vec<StarkPoint>,
sender: PartyId,
test_args: &IntegrationTestArgs,
) -> Vec<MpcStarkPointResult> {
values
.into_iter()
.map(|point| share_point(point, sender, test_args))
.collect_vec()
}
/// Shares `value` from party `sender` through the test fabric and returns the
/// authenticated share.
pub(crate) fn share_authenticated_scalar(
    value: Scalar,
    sender: PartyId,
    test_args: &IntegrationTestArgs,
) -> AuthenticatedScalarResult {
    let fabric = &test_args.fabric;
    fabric.share_scalar(value, sender)
}
/// Batch-shares `values` from party `sender` through the test fabric and
/// returns the authenticated shares.
pub(crate) fn share_authenticated_scalar_batch(
    values: Vec<Scalar>,
    sender: PartyId,
    test_args: &IntegrationTestArgs,
) -> Vec<AuthenticatedScalarResult> {
    let fabric = &test_args.fabric;
    fabric.batch_share_scalar(values, sender)
}
/// Shares the curve point `value` from party `sender` through the test fabric
/// and returns the authenticated share.
pub(crate) fn share_authenticated_point(
    value: StarkPoint,
    sender: PartyId,
    test_args: &IntegrationTestArgs,
) -> AuthenticatedStarkPointResult {
    let fabric = &test_args.fabric;
    fabric.share_point(value, sender)
}
/// Batch-shares the curve points in `values` from party `sender` through the
/// test fabric and returns the authenticated shares.
pub(crate) fn share_authenticated_point_batch(
    values: Vec<StarkPoint>,
    sender: PartyId,
    test_args: &IntegrationTestArgs,
) -> Vec<AuthenticatedStarkPointResult> {
    let fabric = &test_args.fabric;
    fabric.batch_share_point(values, sender)
}
/// Sends a plaintext value from `sender` to the other party.
///
/// On the sending party, `value` is passed to `fabric.send_value`. On every
/// other party, `value` is ignored and the handle from `fabric.receive_value`
/// is returned instead.
pub(crate) fn share_plaintext_value<T: From<ResultValue> + Into<NetworkPayload>>(
    value: ResultHandle<T>,
    sender: PartyId,
    fabric: &MpcFabric,
) -> ResultHandle<T> {
    let is_sender = fabric.party_id() == sender;
    if !is_sender {
        return fabric.receive_value();
    }
    fabric.send_value(value)
}
/// Applies `share_plaintext_value` to each handle in `values`, in order.
///
/// Every handle is cloned, so the input slice is left untouched.
pub(crate) fn share_plaintext_values_batch<T: From<ResultValue> + Into<NetworkPayload> + Clone>(
    values: &[ResultHandle<T>],
    sender: PartyId,
    fabric: &MpcFabric,
) -> Vec<ResultHandle<T>> {
    let mut shared = Vec::with_capacity(values.len());
    for handle in values {
        shared.push(share_plaintext_value(handle.clone(), sender, fabric));
    }
    shared
}
/// A deterministic `SharedValueSource` for tests.
///
/// Every value it produces is derived only from the local party's ID.
#[derive(Clone, Debug)]
pub(crate) struct PartyIDBeaverSource {
    /// The ID of the local party. The `SharedValueSource` impl expects this to
    /// be 0 or 1.
    party_id: u64,
}
impl PartyIDBeaverSource {
pub fn new(party_id: u64) -> Self {
Self { party_id }
}
}
/// Fixed, party-dependent correlated "randomness" for tests.
///
/// Shares are treated as additive: the shared value is the sum of the two
/// parties' outputs. The triplet below only satisfies `a * b = c` under that
/// interpretation. These values are not random and are suitable for tests only.
/// Every method assumes `party_id` is 0 or 1.
impl SharedValueSource for PartyIDBeaverSource {
    /// Party 0 contributes 0 and party 1 contributes 1, so the shared bit is 1.
    fn next_shared_bit(&mut self) -> Scalar {
        assert!(self.party_id == 0 || self.party_id == 1);
        Scalar::from(self.party_id)
    }

    /// Shares of the triple a = 1 + 1 = 2, b = 3 + 0 = 3 and c = 2 + 4 = 6,
    /// which satisfies a * b = c.
    fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar) {
        if self.party_id == 0 {
            (Scalar::from(1u64), Scalar::from(3u64), Scalar::from(2u64))
        } else {
            (Scalar::from(1u64), Scalar::from(0u64), Scalar::from(4u64))
        }
    }

    /// Shares of the pair (r, r^-1) = (1, 1).
    ///
    /// Each party contributes its party ID to both components, so the shares
    /// sum to (0 + 1, 0 + 1) = (1, 1), and 1 * 1 = 1. If both parties returned
    /// a constant (1, 1), the shares would sum to (2, 2). That is not an
    /// inverse pair, because 2 * 2 != 1.
    fn next_shared_inverse_pair(&mut self) -> (Scalar, Scalar) {
        (Scalar::from(self.party_id), Scalar::from(self.party_id))
    }

    /// Party 0 contributes 0 and party 1 contributes 1, so the shared value
    /// is 1.
    fn next_shared_value(&mut self) -> Scalar {
        Scalar::from(self.party_id)
    }
}