use itertools::Itertools;
use mpc_stark::{
algebra::{
authenticated_scalar::AuthenticatedScalarResult,
authenticated_stark_point::AuthenticatedStarkPointResult, scalar::Scalar,
stark_curve::StarkPoint,
},
random_point, PARTY0, PARTY1,
};
use rand::thread_rng;
use crate::{
helpers::{
assert_points_eq, assert_scalars_eq, await_result, await_result_batch,
share_plaintext_value, share_plaintext_values_batch,
},
IntegrationTest, IntegrationTestArgs,
};
/// Checks that an inner product computed over secret-shared vectors opens to the
/// same value as the inner product of the underlying plaintext vectors.
///
/// Each party draws 100 random scalars as its private vector. Party 0's vector plays
/// the role of `a` and party 1's the role of `b`. They are first revealed in the clear
/// to build the reference answer, then secret-shared and combined under MPC.
///
/// Returns `Err` if opening the MPC result fails; otherwise returns whatever
/// `assert_scalars_eq` reports for the expected vs. opened value.
fn test_inner_product(test_args: &IntegrationTestArgs) -> Result<(), String> {
    const LEN: usize = 100;

    let fabric = &test_args.fabric;
    let mut rng = thread_rng();

    // This party's private input vector
    let inputs = (0..LEN).map(|_| Scalar::random(&mut rng)).collect_vec();

    // Reference answer: reveal both parties' vectors and multiply-accumulate in the clear
    let allocated = inputs.iter().map(|v| fabric.allocate_scalar(*v)).collect_vec();
    let lhs_clear =
        await_result_batch(&share_plaintext_values_batch(&allocated, PARTY0, fabric));
    let rhs_clear =
        await_result_batch(&share_plaintext_values_batch(&allocated, PARTY1, fabric));
    let expected: Scalar = lhs_clear.iter().zip(rhs_clear).map(|(l, r)| l * r).sum();

    // Secret-share the same vectors: all of party 0's values first, then party 1's
    let mut lhs_shared = Vec::with_capacity(LEN);
    for value in &inputs {
        lhs_shared.push(fabric.share_scalar(*value, PARTY0));
    }
    let mut rhs_shared = Vec::with_capacity(LEN);
    for value in &inputs {
        rhs_shared.push(fabric.share_scalar(*value, PARTY1));
    }

    // Inner product over the shared values, then open it
    let product: AuthenticatedScalarResult =
        lhs_shared.iter().zip(rhs_shared.iter()).map(|(l, r)| l * r).sum();
    let opened = await_result(product.open_authenticated())
        .map_err(|err| format!("error opening result: {err:?}"))?;

    assert_scalars_eq(expected, opened)
}
/// Checks that a multi-scalar multiplication over secret-shared inputs opens to the
/// same point as the plaintext MSM.
///
/// Each party draws 100 random scalars and 100 random points. Party 0's scalars are
/// paired with party 1's points. Both are first revealed in the clear to compute the
/// reference MSM, then secret-shared and combined under MPC.
///
/// Returns `Err` if opening the MPC result fails; otherwise returns whatever
/// `assert_points_eq` reports for the opened vs. expected point.
fn test_msm(test_args: &IntegrationTestArgs) -> Result<(), String> {
    const LEN: usize = 100;

    let fabric = &test_args.fabric;
    let mut rng = thread_rng();

    // This party's private inputs
    let scalars = (0..LEN).map(|_| Scalar::random(&mut rng)).collect_vec();
    let points = (0..LEN).map(|_| random_point()).collect_vec();

    // Reference answer: reveal party 0's scalars and party 1's points, then MSM in the clear
    let allocd_scalars = scalars.iter().map(|&s| fabric.allocate_scalar(s)).collect_vec();
    let allocd_points = points.iter().map(|&p| fabric.allocate_point(p)).collect_vec();
    let clear_scalars =
        await_result_batch(&share_plaintext_values_batch(&allocd_scalars, PARTY0, fabric));
    let clear_points =
        await_result_batch(&share_plaintext_values_batch(&allocd_points, PARTY1, fabric));
    let expected = StarkPoint::msm(&clear_scalars, &clear_points);

    // Secret-share the same inputs and evaluate the MSM over the shares
    let shared_scalars: Vec<_> =
        scalars.iter().map(|&s| fabric.share_scalar(s, PARTY0)).collect();
    let shared_points: Vec<_> = points.iter().map(|&p| fabric.share_point(p, PARTY1)).collect();
    let mpc_msm = AuthenticatedStarkPointResult::msm(&shared_scalars, &shared_points);

    let opened = await_result(mpc_msm.open_authenticated())
        .map_err(|err| format!("error opening msm result: {err:?}"))?;

    assert_points_eq(opened, expected)
}
/// Evaluates a polynomial at a public point, both in the clear and over secret-shared
/// coefficients, and checks that the two results agree.
///
/// The polynomial has degree 6 and no constant term. In Horner form it is
///
/// `x * (c0 + x * (c1 + x * (c2 + x * (d0 + x * (d1 + x * d2)))))`
///
/// - `c0..c2` are party 0's three random coefficients.
/// - `d0..d2` are party 1's three random coefficients.
/// - Every coefficient is offset by a public modifier whose value comes from party 0.
/// - The evaluation point `x` is the sum of one random scalar contributed by each party.
///
/// Returns `Err` if opening the MPC result fails; otherwise returns whatever
/// `assert_scalars_eq` reports for the opened vs. expected value.
fn test_polynomial_eval(test_args: &IntegrationTestArgs) -> Result<(), String> {
    let fabric = &test_args.fabric;
    let mut rng = thread_rng();

    // Public offset added to every coefficient; party 0's value is used by both parties
    let public_modifier = Scalar::random(&mut rng);
    let public_modifier =
        share_plaintext_value(fabric.allocate_scalar(public_modifier), PARTY0, fabric);

    // Public evaluation point: this party's random value plus the counterparty's
    let my_x = fabric.allocate_scalar(Scalar::random(&mut rng));
    let x = fabric.exchange_value(my_x.clone()) + my_x;
    let x_res = await_result(x.clone());

    // Each party contributes three coefficients. The evaluations below index [0..=2],
    // so this count must stay in sync with them.
    let my_coeffs = (0..3).map(|_| Scalar::random(&mut rng)).collect_vec();
    let my_allocated_coeffs = my_coeffs
        .iter()
        .map(|coeff| fabric.allocate_scalar(*coeff))
        .collect_vec();

    // Reference answer: reveal each party's coefficients, apply the modifier, and
    // evaluate in the clear
    let first_coeffs = await_result_batch(&share_plaintext_values_batch(
        &my_allocated_coeffs,
        PARTY0,
        fabric,
    ))
    .iter()
    .map(|coeff| coeff + &public_modifier)
    .map(await_result)
    .collect_vec();
    let second_coeffs = await_result_batch(&share_plaintext_values_batch(
        &my_allocated_coeffs,
        PARTY1,
        fabric,
    ))
    .iter()
    .map(|coeff| coeff + &public_modifier)
    .map(await_result)
    .collect_vec();

    let expected_res = x_res
        * (first_coeffs[0]
            + x_res
                * (first_coeffs[1]
                    + x_res
                        * (first_coeffs[2]
                            + x_res
                                * (second_coeffs[0]
                                    + x_res * (second_coeffs[1] + x_res * second_coeffs[2])))));

    // Secret-share the same coefficients, apply the modifier, and evaluate under MPC
    let first_shared_coeffs = my_coeffs
        .iter()
        .map(|coeff| fabric.share_scalar(*coeff, PARTY0))
        .map(|coeff| coeff + &public_modifier)
        .collect_vec();
    let second_shared_coeffs = my_coeffs
        .iter()
        .map(|coeff| fabric.share_scalar(*coeff, PARTY1))
        .map(|coeff| coeff + &public_modifier)
        .collect_vec();

    let res = &x
        * (&first_shared_coeffs[0]
            + &x * (&first_shared_coeffs[1]
                + &x * (&first_shared_coeffs[2]
                    + &x * (&second_shared_coeffs[0]
                        + &x * (&second_shared_coeffs[1] + &x * &second_shared_coeffs[2])))));

    let res = await_result(res.open_authenticated())
        .map_err(|err| format!("error opening polynomial eval result: {err:?}"))?;

    assert_scalars_eq(res, expected_res)
}
// Register this module's tests as `IntegrationTest` entries via `inventory::submit!`.
// Each entry pairs a name (namespaced under `circuits::`) with the test function.

// Inner product of two secret-shared vectors
inventory::submit!(IntegrationTest {
    name: "circuits::test_inner_product",
    test_fn: test_inner_product
});
// Multi-scalar multiplication over secret-shared scalars and points
inventory::submit!(IntegrationTest {
    name: "circuits::test_msm",
    test_fn: test_msm
});
// Polynomial evaluation with secret-shared coefficients at a public point
inventory::submit!(IntegrationTest {
    name: "circuits::test_polynomial_eval",
    test_fn: test_polynomial_eval
});